Я сейчас пытаюсь DQN для Ti c Ta c Toe. Я застрял со следующей ошибкой:
элемент 0 тензоров не требует grad и не имеет grad_fn
Я думаю, что ошибка возникает при каждой игре до конца и достигает состояния терминала, поскольку цель затем вручную устанавливается на некоторое значение, и теперь градиент не может быть вычислен во время обратного распространения. Как я могу это исправить?
Вот мой код для обновления NN:
def update_NN(state, next_state, action, player, discount, lr, loss_all):
pred = torch.tensor(net(torch.tensor(state).float().view(-1, 9)).squeeze().detach().numpy()[action])
reward = 0
winner, game_status = check_result(next_state)
if game_status == 'Done' and winner == player:
reward = 100
if game_status == 'Done' and winner != player:
reward = -1
if game_status == 'Draw':
reward = 10
if next_state.count(0) == 0:
target = torch.tensor(reward).float()
else:
target = torch.tensor(reward).float() + discount * torch.max(net(torch.tensor(next_state).float()))
# Evaluate loss
loss = loss_fn(pred, target)
print(loss)
loss_all.append(loss)
optimizer.zero_grad()
# Backward pass
loss.backward()
# Update
optimizer.step()