DQN - веса не меняются во время тренировки - PullRequest
0 голосов
/ 09 мая 2020

Я построил DQN для изучения Ti c Ta c Toe и выяснил, что веса не меняются во время тренировки. Кажется, что-то не так с реализацией моей цели / прогноза или с потерями / оптимизатором. Я новичок в pytorch, поэтому не знаю, где я напортачил. Вот соответствующий код:

def get_pred_and_target(st, next_state, act, player, discount):

    pred = torch.tensor([net(torch.tensor(st).float()).squeeze().detach().numpy()[act]])
    # Define reward
    reward = 0.
    winner, game_status = check_result(next_state)
    if game_status == 'Done' and winner == player:
        reward = 1.
    if game_status == 'Done' and winner != player:
        reward = -1.
    if game_status == 'Draw':
        reward = 1.
    # Define target
    if next_state.count(0) == 0:
        target = torch.tensor([reward], requires_grad=True).float()
    else:
        target = torch.tensor([reward]).float() + discount * torch.max(
            target_net(torch.tensor(st).float()))
    return pred, target

# Training against intelligent agent
num_epochs = 10000
epsilon_array = np.linspace(0.8, 0.1, num_epochs)  # epsilon decays with every epoch
lr_array = np.linspace(1e-2, 1e-9, num_epochs)
results = []
results_val = []
percentages = []
percentages_val = []
preds = torch.tensor([]).float()
targets = torch.tensor([]).float()
training = True
validation = False
playing = False
batch_size = 5
update_target = 20
for param in net.parameters():
    param.requires_grad = True

if training:
    for epoch in range(num_epochs):
        # Define Optimizer
        optimizer = torch.optim.Adam(net.parameters(), lr=lr_array[epoch], weight_decay=1e-8)

        # Produce batch
        for i in range(batch_size):
            # Clear Board
            state = [0, 0, 0, 0, 0, 0, 0, 0, 0]

            epsilon = epsilon_array[epoch]
            game_status = 'Not Done'
            winner = None
            players_turn = np.random.choice([0, 1])

            while game_status == 'Not Done':
                if players_turn == 0:  # X's move
                    # print("\nAI X's turn!")
                    action = select_action(state, epsilon)
                    new_state = play_move(state, 1, action)
                else:  # O's move
                    # print("\nAI O's turn!")
                    action = select_random_action(state)
                    new_state = play_move(state, -1, action)

                # get pred and target for Q(s,a)
                pred, target = get_pred_and_target(state, new_state, action, 1, discount=0.99)
                # update batch
                preds = torch.cat([preds, pred])
                targets = torch.cat([targets, target])
                # update state
                state = new_state.copy()
                # print_board(new_state)
                winner, game_status = check_result(state)
                if winner is not None:
                    # print(str(winner) + ' won!')
                    if winner == 1:
                        results.append('X')
                    else:
                        results.append('O')
                else:
                    players_turn = (players_turn + 1) % 2
                if game_status == 'Draw':
                    # print('Draw!')
                    results.append('Draw')

        loss = loss_fn(preds, targets)
        optimizer.zero_grad()
        # Backward pass
        loss.backward()
        # Update
        optimizer.step()
        # Clear batch
        preds = torch.tensor([]).float()
        targets = torch.tensor([]).float()

        # update target net
        if epoch % update_target == 0:
            print('Epoch: ' + str(epoch))
            print(torch.mean(loss))
            print(pred)
            target_net = pickle.loads(pickle.dumps(net))
            percentage = results[-700:].count('O')/7
            percentages.append(percentage)
            print(f'Random player win percentage in the last 700 games: {percentage} %')
    print('Training Complete')
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...