DQN: ошибка для loss.backward () в состоянии терминала с фиксированной целью - PullRequest
0 голосов
/ 14 апреля 2020

Я сейчас пытаюсь DQN для Ti c Ta c Toe. Я застрял со следующей ошибкой:

элемент 0 тензоров не требует grad и не имеет grad_fn

Я думаю, что ошибка возникает при каждой игре до конца и достигает состояния терминала, поскольку цель затем вручную устанавливается на некоторое значение, и теперь градиент не может быть вычислен во время обратного распространения. Как я могу это исправить?

Вот мой код для обновления NN:

def update_NN(state, next_state, action, player, discount, lr, loss_all):
    pred = torch.tensor(net(torch.tensor(state).float().view(-1, 9)).squeeze().detach().numpy()[action])
    reward = 0
    winner, game_status = check_result(next_state)
    if game_status == 'Done' and winner == player:
        reward = 100
    if game_status == 'Done' and winner != player:
        reward = -1
    if game_status == 'Draw':
        reward = 10

    if next_state.count(0) == 0:
        target = torch.tensor(reward).float()
    else:
        target = torch.tensor(reward).float() + discount * torch.max(net(torch.tensor(next_state).float()))
    # Evaluate loss
    loss = loss_fn(pred, target)
    print(loss)
    loss_all.append(loss)
    optimizer.zero_grad()
    # Backward pass
    loss.backward()
    # Update
    optimizer.step()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...