пропуская формальности:
Я пытаюсь применить усиленное обучение к модели seq2seq на основе трансформатора (для целей абстрактного суммирования) в Pytorch.
Моя текущая установка выглядит примерно так:
Я получаю жадное распределение (сводку) из модели, выводя по одному токену за раз в al oop
def get_greedy_distribution(model, batch):
src, (shift_tgt, lbl_tgt), segs, clss, mask_src, mask_tgt, mask_cls = batch
# the mock targets are just torch.zeros tensors to store inferred tokens
mock_tgt = get_mock_tgt(shift_tgt)
mock_return = get_mock_tgt(shift_tgt)
max_length = shift_tgt.shape[1]
with torch.no_grad():
for i in range(0, max_length-1):
prediction = model(src, mock_tgt, segs, clss, mask_src, mask_tgt, mask_cls)
prediction = F.softmax(prediction, dim=2)
val, ix = prediction.data.topk(1)
mock_tgt[:, i+1] = ix.squeeze()[:, i].detach()
mock_return[:, i] = ix.squeeze()[:, i].detach()
return mock_return
Я получаю выборочное распределение, с вероятностями, из модели аналогичным образом:
def get_distribution(model, batch):
src, (shift_tgt, lbl_tgt), segs, clss, mask_src, mask_tgt, mask_cls = batch
mock_tgt = get_mock_tgt(shift_tgt)
mock_return = get_mock_tgt(shift_tgt)
max_length = shift_tgt.shape[1]
log_probs = []
for i in range(0, max_length-1):
prediction = model(src, mock_tgt, segs, clss, mask_src, mask_tgt, mask_cls)
prediction = F.softmax(prediction, dim=2)
multi_dist = Categorical(prediction[:, i])
x_t = multi_dist.sample()
log_prob = multi_dist.log_prob(x_t)
mock_tgt[:, i+1] = x_t
mock_return[:, i] = x_t
log_probs.append(log_prob)
return mock_return, log_probs
Однако я немного не уверен, правильно ли я понял распределение выборки. Это хорошо работает в контексте RNN, где я могу выбирать логиты и вероятности во время типичного RNN l oop, но при использовании трансформатора это выглядит немного неправильно.
Как бы вы предложили обратиться к трансформатору для Типичная базовая выборка обучения подкреплению обучения (я предполагаю, что это градиент политики)? Код Pytorch предпочтительнее, но если у вас есть примеры Tensorflow, я уверен, что смогу разобраться.