Как применить вознаграждение / рассчитать потери на этапе обучения в режиме подкрепления? - PullRequest
0 голосов
/ 09 февраля 2019

Я создал простую нейронную сеть с pytorch, предназначенную для расчета движений юнитов внутри сетки.

actions = [
    'none',
    'left',
    'right',
    'up',
    'down'
]

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()

        hl_dim = 32
        self.hidden = (
            torch.zeros(2, 1, hl_dim),
            torch.zeros(2, 1, hl_dim)
        )
        in_len = 4
        hl_len = 32
        ou_len = len(actions)
        self.in1 = nn.Linear(in_len, hl_len)
        self.hl1 = nn.LSTM(hl_len, hl_dim, 2, dropout=0.05)
        self.ou1 = nn.Linear(hl_len, ou_len)

    def forward(self, input):
        output = F.relu(self.in1(input.view(1, -1)).unsqueeze(1))
        output, self.hidden = self.hl1(output, self.hidden)
        output = F.relu(self.ou1(output))
        return output


model = Network()
model.zero_grad()

После каждого шага / решения рассчитывается вознаграждение за движение и обновляется мир.

# initial world
pos_x = 6
pos_y = 6
goal_x = 10
goal_y = 10

while True:
    old_state = torch.tensor([pos_x, pos_y, goal_x, goal_y], dtype=torch.float)
    result = model(old_state)

    action = torch.argmax(result)
    action_str = actions[action]

    # calculate world updates...

    terminal = # goal found, 0 or 1
    reward = # in range [-1, +1]
    new_state = torch.tensor([new_pos_x, new_pos_y, goal_x, goal_y], dtype=torch.float)

На основании принятого решения у меня теперь есть old_state, result, new_state, terminal и reward.

Теперь я хочу рассчитать убыток и запустить оптимизатор на основе этих значений.Я предполагаю, что могу использовать для этого общие функции потери и оптимизатора pytorch, которые я предварительно объявил следующим образом:

learningrate = 0.01
loss_function = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=learningrate)

Является ли это правильным подходом для оптимизации в этом случае?И как я могу сгенерировать параметры, необходимые для вызова функции потерь?

...