Как loss.backward () работает для пакетов? - PullRequest
0 голосов
/ 16 апреля 2020

Так что я сейчас тренирую DDQN, чтобы играть на четверке. В каждом состоянии сеть прогнозирует действие как наилучшее действие и соответственно перемещается. Код в основном выглядит следующим образом:

for epoch in range(num_epochs):
        for i in range(batch_size):
                while game is not finished:
                        action = select_action(state)
                        new_state = play_move(state, action)
                        pred, target = get_pred_target(state, new_state, action)
                        preds = torch.cat([preds, pred])
                        targets = torch.cat([targets, target)]
        loss = loss_fn(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Во время обучения сеть становится немного лучше, но нигде не так хорошо, как я ожидал. Размышляя об этом, я задаюсь вопросом, правильно ли я реализовал вызов loss.backward (). Суть в том, что я сохраняю все прогнозы и цели для каждого хода в тензорах pred и target. Однако я не отслеживаю состояния, которые привели к этим прогнозам и целям. Но разве это не необходимо для обратного распространения или эта информация каким-то образом сохраняется?

Большое спасибо!

...