Explanation of reinforcement learning in PyTorch
0 votes / 4 August 2020

I am trying to teach myself reinforcement learning in order to build a chess engine. I found some Python code online that uses PyTorch. I am more familiar with TensorFlow, so could someone help explain what the code does and perhaps help me "translate" it into TensorFlow code?
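For context, this is only my rough guess at what the model defined further below would look like in TensorFlow/Keras (the function name build_policy_net is mine, and I am not confident this is a faithful translation):

import tensorflow as tf

# Rough guess at an equivalent of the PyTorch Network defined below:
# 64 flattened board values -> 512 ReLU units -> 64*64 raw move logits.
def build_policy_net(number_of_actions=64 * 64):
    return tf.keras.Sequential([
        tf.keras.layers.Dense(512, activation="relu", input_shape=(64,),
                              kernel_initializer="glorot_uniform"),  # Keras name for Xavier uniform
        tf.keras.layers.Dense(number_of_actions,
                              kernel_initializer="glorot_uniform"),  # logits; softmax is applied later
    ])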

Also, is the model as currently defined "robust enough" to be worth using? If not, how can I improve it? I tried adding more nn.Linear layers to it, but that seemed to weaken the model.
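To be concrete, by "adding more nn.Linear layers" I mean a deeper variant roughly like the sketch below (the class name DeeperNetwork and the extra hidden size of 512 are just examples of what I tried), with a ReLU after each hidden layer:

import torch.nn as nn
from torch.nn import functional as F

class DeeperNetwork(nn.Module):
    # Deeper variant: 64 -> 512 -> 512 -> 64*64,
    # ReLU after each hidden layer, Xavier initialization everywhere.
    def __init__(self, number_of_actions=64 * 64):
        super(DeeperNetwork, self).__init__()
        self.layer1 = nn.Linear(64, 512)
        self.layer2 = nn.Linear(512, 512)
        self.layer5 = nn.Linear(512, number_of_actions)
        for layer in (self.layer1, self.layer2, self.layer5):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer5(x)  # raw logits; Categorical(logits=...) is applied later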

EDIT: here is the link to the original code: https://colab.research.google.com/drive/1Xk9MibJ9Fli5tIlDvo88hcZrI76rqZN5#scrollTo=ZxHEghUq9JWM

EDIT2: I would also like to figure out how to add a system to the current code that makes use of a memory of previous results. Can someone point me in the right direction?
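What I have in mind is something like a small replay memory. Here is a minimal sketch of the kind of structure I mean (the class and names are mine, not from the original code; I realize plain REINFORCE is on-policy, so reusing stored samples directly is probably not valid without importance weighting or switching methods, which is part of what I am asking about):

import random
from collections import deque

class ReplayMemory:
    # Fixed-size buffer of past (board_tensor, move_index, reward) tuples.
    # This is only the storage part; how to train from these samples correctly
    # is exactly what I am trying to figure out.
    def __init__(self, capacity=10000):
        self.memory = deque(maxlen=capacity)

    def push(self, board_tensor, move_index, reward):
        self.memory.append((board_tensor, move_index, reward))

    def sample(self, batch_size):
        return random.sample(self.memory, min(batch_size, len(self.memory)))

    def __len__(self):
        return len(self.memory)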

Note that the code below has been ADAPTED from that link. The adapted code:

import chess
import chess.pgn
import chess.engine
import torch
import numpy as np
import os
import torch.nn as nn
from torch.nn import functional as F
#os.remove("Games.txt")
def board_to_tensor(board):  
    # Python chess uses flattened representation of the board
    x = torch.zeros(64, dtype=torch.float)
    for pos in range(64):
        piece = board.piece_type_at(pos)
        if piece:
            color = int(bool(board.occupied_co[chess.BLACK] & chess.BB_SQUARES[pos]))
            col = int(pos % 8)
            row = int(pos / 8)
            x[row * 8 + col] = -piece if color else piece
    x = x.reshape(8, 8)        
    return x


def move_to_index_tensor(move):
    square_to_pick_figure = move.from_square
    # Can decode exact position this way:
    #square_to_pick_figure_row = int(square_to_pick_figure / 8)
    #square_to_pick_figure_col = int(square_to_pick_figure % 8)
    square_to_put_figure = move.to_square
    # Can decode exact position this way:
    #square_to_put_figure_row = int(square_to_put_figure / 8)
    #square_to_put_figure_col = int(square_to_put_figure % 8)
    # Encode the (from_square, to_square) pair as a single index in [0, 64*64)
    index = square_to_pick_figure * 64 + square_to_put_figure
    index_tensor = torch.LongTensor([index])
    return index_tensor
  
def filter_legal_moves(legal_moves):
    filtered_legal_moves = []
    for legal_move in legal_moves:
        # Here we check if it is a promotion and
        # only leave promotion if it is a promotion to a queen
        if legal_move.promotion is not None:
            if legal_move.promotion == 5:
                filtered_legal_moves.append(legal_move)
            continue
        filtered_legal_moves.append(legal_move)
    return filtered_legal_moves


def legal_moves_to_index_tensors(legal_moves):
    legal_moves_index_tensors = [move_to_index_tensor(legal_move) for legal_move in legal_moves]
    return legal_moves_index_tensors

# The input to the network is a tensor of size 8*8 (it is flattened)
# The output of the network is 64*64 (it is flattened too)
# The size of the hidden layer should be 512

class Network(nn.Module):
    def __init__(self, number_of_actions=64*64):
        super(Network, self).__init__()
        
        # Two fully connected layers: 64 board values -> 512 hidden units -> one logit per (from, to) move
        self.layer1 = nn.Linear(64, 512)
        self.layer5 = nn.Linear(512, number_of_actions)
        
        # Initialization of weights in the layers
        nn.init.xavier_uniform_(self.layer1.weight)
        nn.init.xavier_uniform_(self.layer5.weight)
        
                
    def forward(self, x):
        x =  F.relu( self.layer1(x) ) 
        # Logits will be fed into softmax layer to get probabilities
        # for each move later.
        logits =  self.layer5(x)
        
        return logits

net = Network(number_of_actions=64*64)

def discount_rewards(collected_moves, gamma=0.99):
    # Propagate the final reward backwards through the game, discounted by gamma
    running_reward = 0.0
    for collected_move in reversed(collected_moves):
        reward = collected_move[1]
        running_reward = running_reward * gamma + reward
        collected_move[1] = running_reward
  

def normalize_rewards(collected_moves):
    # Standardize rewards to zero mean and unit variance across all collected moves
    normalized_rewards = np.asarray([move[1] for move in collected_moves], dtype=np.float64)
    normalized_rewards -= np.mean(normalized_rewards)
    normalized_rewards /= np.std(normalized_rewards)

    for index, collected_move in enumerate(collected_moves):
        collected_move[1] = normalized_rewards[index]

from torch.distributions import Categorical
import random


def get_games_data(policy_net, episodes=100,d=1):
    all_moves = [] 
    lost_count = 0
    draw_count = 0
    win_count = 0
    game_lengths_sum = 0.0
    
    for episode in range(episodes):
        
        engine = chess.engine.SimpleEngine.popen_uci("stockfish-5-linux/Linux/stockfish_14053109_x64")
        engine.configure({"Clear Hash": True})
        
        board = chess.Board(fen='rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1')
        collected_moves = []
        board_sign = 1
        move_counter = 0.0
        game_record = ""  # textual record of the game's moves for logging
        while not board.is_game_over():
            if board_sign == 1:
                # Converting the board to a tensor representation
                board_tensor = board_to_tensor(board).reshape(-1)
                board_tensor_batched = board_tensor.unsqueeze(0)
                # Getting the logits output
                logits = policy_net(board_tensor_batched)
                # Now we need to select only legal moves
                # in python-chess format
                current_legal_moves = filter_legal_moves(list(board.legal_moves))
                legal_moves_index_tensors = legal_moves_to_index_tensors(current_legal_moves)
                # Keep only the logits of the legal moves; torch.cat turns the list of
                # one-element index tensors into a flat index, giving shape [1, num_legal_moves]
                legal_moves_logits = logits[:, torch.cat(legal_moves_index_tensors)]
                # Here we sample the action using valid logits
                categorical_sampler = Categorical(logits=legal_moves_logits)
                sampled_action = categorical_sampler.sample()
                sampled_action_move_object = current_legal_moves[sampled_action]
                log_prob = categorical_sampler.log_prob(sampled_action)
                board.push(sampled_action_move_object)
                #print(str(sampled_action_move_object),end = " ")
                game_record += str(sampled_action_move_object)
                # Store the log-probability of the sampled move and its reward (0.0 until the game ends)
                collected_moves.append([log_prob, 0.0])
            else:
                result = engine.play(board, chess.engine.Limit(depth=d, nodes=3))
                board.push(result.move)
                game_record += " " + str(result.move) + " "
                #print(str(result.move),end=" ")

            board_sign = board_sign * -1
            move_counter = move_counter + 1

        #print(f"\n--------------{board.result()}---------------\n")
        with open("Games.txt","a") as file:
          file.write(game_record)
          file.write(str(board.result()))
          file.write("\n\n\n")
        if board.is_checkmate():
          if board_sign == 1:
            # board_sign has already been flipped after the mating move, so the engine mated the net
            reward = -1.0
            lost_count = lost_count + 1
          else:
            # The net delivered the mate
            reward = 1.0
            win_count = win_count + 1
        else:
          # Game over without checkmate: treat it as a draw with a small positive reward
          reward = 0.1
          draw_count = draw_count + 1
        game_lengths_sum = game_lengths_sum + move_counter    
        collected_moves[-1][1] = reward       
        discount_rewards(collected_moves)
        all_moves.extend(collected_moves)    
        engine.quit()
    
    average_game_length = game_lengths_sum / episodes
    stats = { "lost": lost_count,
              "draw": draw_count,
              "win": win_count,
            }
    normalize_rewards(all_moves)
    return all_moves, stats, win_count

import torch.optim as optim

net = Network(number_of_actions=64*64)
optimizer = optim.Adam(net.parameters(), lr=0.01)

checkpoint = torch.load("/content/drive/My Drive/Checkpoint2Layer.pt")
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

from livelossplot import PlotLosses
liveloss = PlotLosses()
# Stop it when you are happy with the displayed results
# Each iteration takes a while, be patient
dep = 1  # Stockfish search depth used for the opponent
while dep < 11:
  try:
    collected_moves, stats, wins = get_games_data(policy_net=net, episodes=100, d=dep)
    # Move to a deeper (stronger) Stockfish once the net wins more than 70 of the 100 games
    if wins > 70:
      dep += 1
    logs = [collected_move[0] for collected_move in collected_moves]
    rewards = [collected_move[1] for collected_move in collected_moves]
    logs_tensor = torch.cat(logs)
    rewards_tensor = torch.FloatTensor(rewards)
    optimizer.zero_grad()
    # REINFORCE / policy-gradient update: minimize -sum(log_prob * normalized discounted return)
    policy_loss = -(logs_tensor * rewards_tensor).sum()
    policy_loss.backward()
    optimizer.step()
    liveloss.update(stats)
    liveloss.draw()
  except KeyboardInterrupt:
    break
torch.save({
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, "Checkpoint.pt")

#torch.save(net.state_dict(),"ChessNet.pt")
...