Пакетное обучение Pytorch Denoising Autoencoder с LSTM - PullRequest
0 голосов
/ 11 февраля 2020

У меня есть этот код для обучения шумоподавляющего автоэнкодера, который использует LSTM для кодера и декодера и работает с именами

def denoise_train(x: DataLoader):
    loss = 0.
    noisy_x = list(map(lambda s: noise_name(s), x))

    rnn_x = to_rnn_tensor(x, DECODER_COUNT)
    rnn_noisy_x = to_rnn_tensor(noisy_x, ENCODER_COUNT)

    encoder_hidden = encoder.init_hidden(batch_size=BATCH_SZ)

    for i in range(rnn_noisy_x.shape[0]):
        _, encoder_hidden = encoder(rnn_noisy_x[i].unsqueeze(0), encoder_hidden)

    decoder_input = strings_to_tensor([SOS] * BATCH_SZ)

    decoder_hidden = encoder_hidden

    name = ''

    for i in range(rnn_x.shape[0]):

        decoder_probs, decoder_hidden = decoder(decoder_input, decoder_hidden)

        _, nonzero_indexes = rnn_x[i].topk(1)

        # TODO!!! Need to fix rest of code for batch

        best_index = torch.argmax(decoder_probs, dim=2).item()

        loss += criterion(decoder_probs[0], nonzero_indexes[0])

        name += ALL_CHARS[best_index]

        decoder_input = torch.zeros(1, 1, LETTERS_COUNT)

        decoder_input[0, 0, best_index] = 1.

    loss.backward()
    return name, noisy_x, loss.item()

x, который передается в параметре функции, является следующей итерацией итератора (DataLoader). Главное, что я пытаюсь сделать, - это получить argmax для всех decoder_probs, которые выходят в пакетном режиме, размер имени файла x размер пакета x длина вывода. Поэтому мне нужно, чтобы best_index был argmax для всех записей в пакете, а decoder_input должен быть размером 1xbatch x, где все лучшие символы = 1. Как получить argmax для всего пакета в тензоре decoder_probs?

...