Reinforcement learning in a bidirectional RNN

asked 28 February 2020

I have been studying deep generative neural networks for a while now. I am fine with the basics, but I could really use some guidance and a quick start.

Recently I came across the paper "Bidirectional Molecule Generation with Recurrent Neural Networks" (https://pubs.acs.org/doi/10.1021/acs.jcim.9b00943, code: https://github.com/ETHmodlab/BIMODAL).

I am trying to add code to apply reinforcement learning to the BIMODAL network. The authors provide pre-trained models from which I can sample a given number of SMILES strings.

I am basically trying to do something similar to the published REINVENT network (https://arxiv.org/abs/1704.07555, Olivecrona et al., 2017; code: https://github.com/MarcusOlivecrona/REINVENT),

but I am not sure where to start when it comes to defining and updating a proper loss function based on a reward/scoring function for the generated structures (e.g., maximizing penalized logP, QED, etc.). I do know my reward function. Let's say it is QED (a drug-likeness metric for a molecule, ranging from 0 to 1).
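
Computing the reward itself is not the issue; RDKit already provides QED. A minimal check (standard RDKit API; 'CCO' is just a throwaway test molecule):

from rdkit import Chem
from rdkit.Chem import QED

mol = Chem.MolFromSmiles('CCO')  # ethanol, as a trivial example
print(QED.default(mol))          # drug-likeness score between 0 and 1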

For now I simply define my loss as an MSE loss. (I doubt this is the right approach though, but I am not sure what else I could do. I am not even sure I should be using an MSE loss at all; maybe I need to consider something like logits and probabilities instead.)
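
For comparison, what REINVENT does (as I understand the paper) is score each sampled sequence under both a frozen prior and the trainable agent, and regress the agent's log-likelihood towards a reward-augmented target. A rough sketch of that loss; agent_log_prob and prior_log_prob are per-sequence log-likelihoods that I would still need to extract from BIMODAL, and sigma is the reward weight from the paper:

def reinvent_style_loss(agent_log_prob, prior_log_prob, score, sigma=60.0):
    """REINVENT-style augmented-likelihood loss (sketch, not tested).

    agent_log_prob: summed log P(sequence) under the trainable agent;
                    must stay attached to the computation graph
    prior_log_prob: summed log P(sequence) under the frozen prior
    score:          reward in [0, 1], e.g. QED of the decoded molecule
    """
    augmented = prior_log_prob + sigma * score    # target log-likelihood
    return (augmented - agent_log_prob) ** 2      # squared error per sequence

The key difference from my MSE-on-scores attempt is that agent_log_prob would be built from the model's own outputs, so the gradient can actually flow back into the network weights.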

I retrained the model for 200 steps, but there was no improvement in the score/reward.

Can anyone give me some pointers, please? I know this is a long question, but thanks a lot in advance.

This is my code attempting to retrain the pre-trained model with reinforcement learning:

import torch
import torch.nn as nn
from rdkit import Chem
from rdkit.Chem import QED  # makes Chem.QED available

# Imports from the BIMODAL repository (module paths may need adjusting)
from one_hot_encoder import SMILESEncoder
from bimodal import BIMODAL
from helper import clean_molecule


class Trainer_rl():

    def __init__(self):

        self._encoder = SMILESEncoder()
        self._model_type = 'BIMODAL'
        self._model = BIMODAL(molecule_size=151, encoding_dim=55,
                          lr=0.001, hidden_units=128)
        self._start_model = "../evaluation/BIMODAL_random_512_FineTuning_template/pretrained_model"
        self._starting_token = self._encoder.encode('G')
        self._T = 0.7

    def score(self, mol):
        return Chem.QED.default(mol)

    def loss(self, values):
        ones = torch.ones(len(values), dtype=torch.float32)
        # NOTE: this tensor is rebuilt from plain Python floats, so it is detached
        # from the model's computation graph (possibly part of my problem?)
        values = torch.tensor(values, requires_grad=True, dtype=torch.float32)
        # The maximum of the reward function is one, so I use a tensor of ones
        # as the regression target for the MSELoss
        loss = nn.MSELoss()
        return loss(values, ones)

    def hyperparameter_update(self, decrease_by=0.1):
        for param_group in self._model._optimizer.param_groups:
            param_group["lr"] *= (1 - decrease_by)

    def train_agent(self, num_steps=1000, batch_size=64):

        # Load pre-trained model
        self._model.build(self._start_model)

        #Training loop
        for i in range(num_steps):
            self._model._optimizer.zero_grad()
            gen_SMILESs = []
            scores = []

            #This part basically generates batch of SMILES and calculate the scores of them.
            #score (QED) ranged from 0 to 1. if the generated molecule is invalid, I set it to -1.
            gen_SMILESs = [self._encoder.decode(self._model.sample(self._starting_token, self._T)) for x in range(batch_size)]
            print(f"length of gen_SMILES, {len(gen_SMILESs)}")
            clean_gen_SMILESs = [clean_molecule(s[0], self._model_type) for s in gen_SMILESs]
            print(f"length of clean_SMILES, {len(clean_gen_SMILESs)}")
            scores = [self.score(Chem.MolFromSmiles(smi)) if Chem.MolFromSmiles(smi) else -1 for smi in clean_gen_SMILESs]
            print(f"mean reward = {sum(rewards)/len(rewards)}")

            for x, smiles in enumerate(clean_gen_SMILESs):
                print("Step {}, Generated SMILES No.{}: {}".format(i, x, smiles))
                mol = Chem.MolFromSmiles(smiles)
                if mol:
                    print("QED={}".format(self.score(mol)))
                else:
                    print("Invalid mol")

            loss = self.loss(scores)
            print(f"Current loss: {loss}")

            #decrease learning rate by 10% every 10 steps (just for testing)
            if i % 10 == 0 and i != 0:
                self.hyperparameter_update()
            for param_group in self._model._optimizer.param_groups:
                print("Current learning rate {}".format(param_group["lr"]))
            loss.backward()
            self._model._optimizer.step()

if __name__ == "__main__":
    s = Trainer_rl()
    s.train_agent(num_steps=30, batch_size=10)
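
What I think I am missing is a way to get the log-probability of each sampled token, so that the loss is differentiable with respect to the network weights. Something along these lines is what I have in mind; it is only a sketch that mirrors the alternating left/right loop of BIMODAL's sample(), and using log_softmax on the prediction logits is my assumption about how to turn them into log-probabilities:

def sequence_log_prob(model, seq):
    """Re-score a generated sequence and return its summed token log-probability.

    Sketch, not tested. `model` is a BIMODAL instance; `seq` is the one-hot
    sequence as a torch tensor of shape (molecule_size, 1, encoding_dim) on
    model._device (e.g. rebuilt from the array that sample() returns).
    Unlike sample(), this runs without torch.no_grad(), so the result stays
    attached to the graph and loss.backward() can reach the weights.
    """
    model._lstm.train()
    log_prob = 0.0
    start = model._molecule_size // 2
    end = start + 1
    for j in range(model._molecule_size - 1):
        model._lstm.new_sequence(1, model._device)
        direction = 'right' if j % 2 == 0 else 'left'
        pred = model._lstm(seq[start:end], direction, model._device)  # (1, encoding_dim)
        log_p = torch.log_softmax(pred, dim=-1)
        if j % 2 == 0:
            token = seq[end, 0].argmax()        # token that was sampled at this position
            end += 1
        else:
            token = seq[start - 1, 0].argmax()
            start -= 1
        log_prob = log_prob + log_p[0, token]
    return log_prob

With that, a REINFORCE-style objective could be something like loss = -(score * log_prob) averaged over the batch, or the REINVENT-style loss above. But I am not sure this is the right way to wire it into BIMODAL, hence the question.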

The BIMODAL class from the original source code:

class BIMODAL():

def __init__(self, molecule_size=7, encoding_dim=55, lr=.01, hidden_units=128):

    self._molecule_size = molecule_size
    self._input_dim = encoding_dim
    self._output_dim = encoding_dim
    self._layer = 2
    self._hidden_units = hidden_units

    # Learning rate
    self._lr = lr

    # Build new model
    self._lstm = BiDirLSTM(self._input_dim, self._hidden_units, self._layer)

    # Check availability of GPUs
    self._gpu = torch.cuda.is_available()
    self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        self._lstm = self._lstm.cuda()
        print('GPU available')

    # Adam optimizer
    self._optimizer = torch.optim.Adam(self._lstm.parameters(), lr=self._lr, betas=(0.9, 0.999))
    # Cross entropy loss
    self._loss = nn.CrossEntropyLoss(reduction='mean')

def build(self, name=None):
    """Build new model or load model by name
    :param name:    model name
    """

    if (name is None):
        self._lstm = BiDirLSTM(self._input_dim, self._hidden_units, self._layer)

    else:
        self._lstm = torch.load(name + '.dat', map_location=self._device)

    if torch.cuda.is_available():
        self._lstm = self._lstm.cuda()

    self._optimizer = torch.optim.Adam(self._lstm.parameters(), lr=self._lr, betas=(0.9, 0.999))

def print_model(self):
    '''Print name and shape of all tensors'''
    for name, p in self._lstm.state_dict().items():
        print(name)
        print(p.shape)

def train(self, data, label, epochs=1, batch_size=1):
    '''Train the model
    :param  data:   data array (n_samples, molecule_size, encoding_length)
    :param  label:  label array (n_samples, molecule_size)
    :param  epochs: number of epochs for the training
    :param  batch_size: batch size for the training
    :return statistic:  array storing computed losses (epochs, batch size)
    '''

    # Number of samples
    n_samples = data.shape[0]

    # Change axes from (n_samples, molecule_size, encoding_dim) to (molecule_size, n_samples, encoding_dim)
    data = np.swapaxes(data, 0, 1)

    # Create tensor from label
    label = torch.from_numpy(label).to(self._device)

    # Calculate number of batches per epoch
    if (n_samples % batch_size) == 0:
        n_iter = n_samples // batch_size
    else:
        n_iter = n_samples // batch_size + 1

    # To store losses
    statistic = np.zeros((epochs, n_iter))

    # Prepare model for training
    self._lstm.train()

    # Iteration over epochs
    for i in range(epochs):

        # Iteration over batches
        for n in range(n_iter):

            # Set gradient to zero for batch
            self._optimizer.zero_grad()

            # Store losses in list
            losses = []

            # Compute indices used as batch
            batch_start = n * batch_size
            batch_end = min((n + 1) * batch_size, n_samples)

            # Reset model with correct batch size
            self._lstm.new_sequence(batch_end - batch_start, self._device)

            # Current batch
            batch_data = torch.from_numpy(data[:, batch_start:batch_end, :].astype('float32')).to(self._device)

            # Initialize start and end position of sequence read by the model
            start = self._molecule_size // 2
            end = start + 1

            for j in range(self._molecule_size - 1):
                self._lstm.new_sequence(batch_end - batch_start, self._device)

                # Select direction for next prediction
                if j % 2 == 0:
                    dir = 'right'
                else:
                    dir = 'left'

                # Predict next token
                pred = self._lstm(batch_data[start:end], dir, self._device)

                # Compute loss and extend sequence read by the model
                if j % 2 == 0:
                    loss = self._loss(pred, label[batch_start:batch_end, end])
                    end += 1

                else:
                    loss = self._loss(pred, label[batch_start:batch_end, start - 1])
                    start -= 1

                # Append loss of current position
                losses.append(loss.item())

                # Accumulate gradients
                # (NOTE: This is more memory-efficient than summing the loss and computing the final gradient for the sum)
                loss.backward()

            # Store statistics: loss per token (middle token not included)
            statistic[i, n] = np.sum(losses) / (self._molecule_size - 1)

            # Perform optimization step
            self._optimizer.step()

    return statistic

def validate(self, data, label, batch_size=128):
    ''' Validation of model and compute error
    :param data:    test data (n_samples, molecule_size, encoding_size)
    :param label:   label data (n_samples, molecule_size)
    :param batch_size:  batch size for validation
    :return:            mean loss over test data
    '''

    # Use train mode to get loss consistent with training
    self._lstm.train()

    # Gradients are not computed, to reduce memory requirements
    with torch.no_grad():
        # Compute tensor of labels
        label = torch.from_numpy(label).to(self._device)

        # Number of samples
        n_samples = data.shape[0]

        # Change axes from (n_samples, molecule_size, encoding_dim) to (molecule_size , n_samples, encoding_dim)
        data = np.swapaxes(data, 0, 1).astype('float32')

        # Initialize loss for complete validation set
        tot_loss = 0

        # Calculate number of batches per epoch
        if (n_samples % batch_size) == 0:
            n_iter = n_samples // batch_size
        else:
            n_iter = n_samples // batch_size + 1

        for n in range(n_iter):

            # Compute indices used as batch
            batch_start = n * batch_size
            batch_end = min((n + 1) * batch_size, n_samples)

            # Data used in this batch
            batch_data = torch.from_numpy(data[:, batch_start:batch_end, :].astype('float32')).to(self._device)

            # Initialize loss for molecule
            molecule_loss = 0

            # Reset model with correct batch size and device
            self._lstm.new_sequence(batch_end - batch_start, self._device)

            start = self._molecule_size // 2
            end = start + 1

            for j in range(self._molecule_size - 1):
                self._lstm.new_sequence(batch_end - batch_start, self._device)

                # Select direction for next prediction
                if j % 2 == 0:
                    dir = 'right'
                if j % 2 == 1:
                    dir = 'left'

                # Predict next token
                pred = self._lstm(batch_data[start:end], dir, self._device)

                # Extend reading of the sequence
                if j % 2 == 0:
                    loss = self._loss(pred, label[batch_start:batch_end, end])
                    end += 1

                if j % 2 == 1:
                    loss = self._loss(pred, label[batch_start:batch_end, start - 1])
                    start -= 1

                # Sum loss over molecule
                molecule_loss += loss.item()

            # Add loss per token to total loss (start token and end token not counted)
            tot_loss += molecule_loss / (self._molecule_size - 1)

        return tot_loss / n_iter

def sample(self, middle_token, T=1):
    '''Generate new molecule
    :param middle_token:    starting sequence
    :param T:               sampling temperature
    :return molecule:       newly generated molecule (molecule_length, encoding_length)
    '''

    # Prepare model
    self._lstm.eval()

    # Gradients are not computed, to reduce memory requirements
    with torch.no_grad():
        # Output array with merged forward and backward directions

        # New sequence
        seq = np.zeros((self._molecule_size, 1, self._output_dim))
        seq[self._molecule_size // 2, 0] = middle_token

        # Create tensor for data and select correct device
        seq = torch.from_numpy(seq.astype('float32')).to(self._device)

        # Define start/end values for reading
        start = self._molecule_size // 2
        end = start + 1

        for j in range(self._molecule_size - 1):
            self._lstm.new_sequence(1, self._device)

            # Select direction for next prediction
            if j % 2 == 0:
                dir = 'right'
            if j % 2 == 1:
                dir = 'left'

            pred = self._lstm(seq[start:end], dir, self._device)

            # Compute new token
            token = self.sample_token(np.squeeze(pred.cpu().detach().numpy()), T)

            # Set new token within sequence
            if j % 2 == 0:
                seq[end, 0, token] = 1.0
                end += 1

            if j % 2 == 1:
                seq[start - 1, 0, token] = 1.0
                start -= 1

    return seq.cpu().numpy().reshape(1, self._molecule_size, self._output_dim)

def sample_token(self, out, T=1.0):
    ''' Sample token
    :param out: output values from model
    :param T:   sampling temperature
    :return:    index of predicted token
    '''
    # Explicit conversion to float64 avoiding truncation errors
    out = out.astype('float64')

    # Compute probabilities with specific temperature
    out_T = out / T
    p = np.exp(out_T) / np.sum(np.exp(out_T))

    # Generate new token at random
    char = np.random.multinomial(1, p, size=1)
    return np.argmax(char)

def save(self, name='test_model'):
    torch.save(self._lstm, name + '.dat')

The BiDirLSTM class (used by the BIMODAL class), copied from the original source code:

class BiDirLSTM(nn.Module):

def __init__(self, input_dim=110, hidden_dim=256, layers=2):
    super(BiDirLSTM, self).__init__()

    # Dimensions
    self._input_dim = input_dim
    self._hidden_dim = hidden_dim
    self._output_dim = input_dim

    # Number of LSTM layers
    self._layers = layers

    # LSTM for forward and backward direction
    self._blstm = nn.LSTM(input_size=self._input_dim, hidden_size=self._hidden_dim, num_layers=layers, dropout=0.3, bidirectional=True)

    # All weights initialized with xavier uniform
    nn.init.xavier_uniform_(self._blstm.weight_ih_l0)
    nn.init.xavier_uniform_(self._blstm.weight_ih_l1)
    nn.init.orthogonal_(self._blstm.weight_hh_l0)
    nn.init.orthogonal_(self._blstm.weight_hh_l1)

    # Biases initialized with zeros, except the bias of the forget gate
    self._blstm.bias_ih_l0.data.fill_(0.0)
    self._blstm.bias_ih_l0.data[self._hidden_dim:2 * self._hidden_dim].fill_(1.0)

    self._blstm.bias_ih_l1.data.fill_(0.0)
    self._blstm.bias_ih_l1.data[self._hidden_dim:2 * self._hidden_dim].fill_(1.0)

    self._blstm.bias_hh_l0.data.fill_(0.0)
    self._blstm.bias_hh_l0.data[self._hidden_dim:2 * self._hidden_dim].fill_(1.0)

    self._blstm.bias_hh_l1.data.fill_(0.0)
    self._blstm.bias_hh_l1.data[self._hidden_dim:2 * self._hidden_dim].fill_(1.0)

    # Layer normalization (weights initialized with one and bias with zero)
    self._norm_0 = nn.LayerNorm(self._input_dim, eps=.001)
    self._norm_1 = nn.LayerNorm(2 * self._hidden_dim, eps=.001)

    # Separate linear model for forward and backward computation
    self._wpred = nn.Linear(2 * self._hidden_dim, self._output_dim)
    nn.init.xavier_uniform_(self._wpred.weight)
    self._wpred.bias.data.fill_(0.0)

def _init_hidden(self, batch_size, device):
    '''Initialize hidden states
    :param batch_size:   size of the new batch
           device:       device where new tensor is allocated
    :return: new hidden state
    '''

    return (torch.zeros(2 * self._layers, batch_size, self._hidden_dim).to(device),
            torch.zeros(2 * self._layers, batch_size, self._hidden_dim).to(device))

def new_sequence(self, batch_size=1, device="cpu"):
    '''Prepare model for a new sequence
    :param batch_size:   size of the new batch
           device:       device where new tensor should be allocated
    :return:
    '''
    self._hidden = self._init_hidden(batch_size, device)
    return

def check_gradients(self):
    '''Print gradients'''
    print('Gradients Check')
    for p in self.parameters():
        print('1:', p.grad.shape)
        print('2:', p.grad.data.norm(2))
        print('3:', p.grad.data)

def forward(self, input, next_prediction='right', device="cpu"):
    '''Forward computation
    :param input:  tensor (sequence length, batch size, encoding size)
    :param next_prediction:    new token is predicted for the left or right side of existing sequence
    :param device:  device where computation is executed
    :return pred:   prediction (batch size, encoding size)
    '''

    # If next prediction is appended at the left side, the sequence is inverted such that
    # forward and backward LSTM always read the sequence along the forward and backward direction, respectively.
    if next_prediction == 'left':
        # Reverse copy of numpy array of given tensor
        input = np.flip(input.cpu().numpy(), 0).copy()
        input = torch.from_numpy(input).to(device)

    # Normalization over encoding dimension
    norm_0 = self._norm_0(input)

    # Compute LSTM unit
    out, self._hidden = self._blstm(norm_0, self._hidden)

    # out (sequence length, batch_size, 2 * hidden dim)
    # Get last prediction from forward (0:hidden_dim) and backward direction (hidden_dim:2*hidden_dim)
    for_out = out[-1, :, 0:self._hidden_dim]
    back_out = out[0, :, self._hidden_dim:]

    # Combine predictions from forward and backward direction
    bmerge = torch.cat((for_out, back_out), -1)

    # Normalization over hidden dimension
    norm_1 = self._norm_1(bmerge)

    # Linear unit forward and backward prediction
    pred = self._wpred(norm_1)

    return pred
...