У меня есть этот код для обучения шумоподавляющего автоэнкодера, который использует LSTM для кодера и декодера и работает с именами
def denoise_train(x: DataLoader):
loss = 0.
noisy_x = list(map(lambda s: noise_name(s), x))
rnn_x = to_rnn_tensor(x, DECODER_COUNT)
rnn_noisy_x = to_rnn_tensor(noisy_x, ENCODER_COUNT)
encoder_hidden = encoder.init_hidden(batch_size=BATCH_SZ)
for i in range(rnn_noisy_x.shape[0]):
_, encoder_hidden = encoder(rnn_noisy_x[i].unsqueeze(0), encoder_hidden)
decoder_input = strings_to_tensor([SOS] * BATCH_SZ)
decoder_hidden = encoder_hidden
name = ''
for i in range(rnn_x.shape[0]):
decoder_probs, decoder_hidden = decoder(decoder_input, decoder_hidden)
_, nonzero_indexes = rnn_x[i].topk(1)
# TODO!!! Need to fix rest of code for batch
best_index = torch.argmax(decoder_probs, dim=2).item()
loss += criterion(decoder_probs[0], nonzero_indexes[0])
name += ALL_CHARS[best_index]
decoder_input = torch.zeros(1, 1, LETTERS_COUNT)
decoder_input[0, 0, best_index] = 1.
loss.backward()
return name, noisy_x, loss.item()
x, который передается в параметре функции, является следующей итерацией итератора (DataLoader). Главное, что я пытаюсь сделать, - это получить argmax для всех decoder_probs, которые выходят в пакетном режиме, размер имени файла x размер пакета x длина вывода. Поэтому мне нужно, чтобы best_index был argmax для всех записей в пакете, а decoder_input должен быть размером 1xbatch x, где все лучшие символы = 1. Как получить argmax для всего пакета в тензоре decoder_probs?