Модель Pytorch не обновляет вес - PullRequest
0 голосов
/ 10 октября 2019

Я пытаюсь решить проблему CartPole с pytorch, но после нескольких итераций параметры не обновляются.

Код, который я пытаюсь воспроизвести, - это [cartpole https://github.com/gsurma/cartpole/blob/master/cartpole.py] сделано в кератах.

import random
from collections import deque

import gym
import numpy as np
import torch.nn as nn
import torch

GAMMA = 0.95

MEMORY_SIZE = 1000000
BATCH_SIZE = 50

EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.995


# Principal Neural Netword module
class DQNSolver(nn.Module):
    def __init__(self, observation_space, action_space):
        super(DQNSolver, self).__init__()
        self.action_space = action_space
        self.observation_space = observation_space

        self.hiddenSpace = 24

        self.fc1 = nn.Linear(self.observation_space, self.hiddenSpace)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(self.hiddenSpace, self.action_space)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out


class cartpole():

    def __init__(self):

        # Device configuration
        device = torch.device('cpu' if torch.cuda.is_available() else 'cpu')

        self.env = gym.make("CartPole-v1")
        self.observation_space = self.env.observation_space.shape[0]
        self.action_space = self.env.action_space.n

        self.dqn_solver = DQNSolver(self.observation_space, self.action_space).to(device)

        # Create the memory
        self.memory = deque(maxlen=MEMORY_SIZE)
        self.exploration_rate = EXPLORATION_MAX

        # Create the optimizer and the loss
        self.optimizer = torch.optim.Adam(self.dqn_solver.parameters(), lr=0.1)
        self.loss_func = torch.nn.MSELoss()

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    # This custom function will receive a observation of the state of the
    # environment, and then return the action
    def predict_action(self, observation):
        # if np.random.rand() < self.exploration_rate:
        #     randomQ = np.random.rand(1, self.action_space)[0]
        #     return np.double(randomQ)
        predicted = self.dqn_solver(observation)
        predicted = predicted.cpu().data.numpy()
        return np.double(predicted)

    def optimize_model(self, state, q_values):
        output = self.predict_action(state)
        output = torch.tensor(output, requires_grad=False)
        qValues = torch.tensor(q_values, requires_grad=True)

        self.optimizer.zero_grad()
        loss = self.loss_func(output,  qValues)
        loss.backward()
        self.optimizer.step()


        print("Loss: {}".format(loss))

    def experience_replay(self):
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        for state, action, reward, state_next, terminal in batch:
            q_update = reward
            if not terminal:
                next_action = self.predict_action(state_next)
                q_update = (reward + GAMMA * np.amax(next_action))
            q_values = self.predict_action(state)
            q_values[action] = q_update

            self.optimize_model(state, q_values)
        print("Finished replay")
        print('weights after backpropagation = ',   list(self.dqn_solver.parameters()))
        self.exploration_rate *= EXPLORATION_DECAY
        self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate)

    def run(self):
        while True:
            state = self.env.reset()
            # state = np.reshape(state, [1, self.observation_space])
            state = torch.Tensor(state)
            while True:
                self.env.render()
                action = self.predict_action(state)
                action = np.argmax(action)
                state_next, reward, terminal, info = self.env.step(action)
                reward = reward if not terminal else -reward
                state_next = torch.Tensor(state_next)
                self.remember(state, action, reward, state_next, terminal)
                state = state_next
                if terminal:
                    break
                self.experience_replay()


if __name__ == "__main__":
    cartpole().run()

после нескольких эпох параметры, напечатанные в этой строке, одинаковы:

print('weights after backpropagation = ',   list(self.dqn_solver.parameters()))

также значения потерь ближе к случайному значению, но нет увеличивается или уменьшается. Что может быть не так?

1 Ответ

0 голосов
/ 10 октября 2019

Задача

В optimize_model вы заменяете q_value новым тензором, который не является узлом исходного вычислительного графа, поэтому градиенты не могут быть переданы обратно через сеть. См. Пример ниже:

import torch
import torch.nn as nn
import torch.optim as optim

clf = nn.Linear(2, 2)
opt = optim.SGD(clf.parameters(), lr=0.1)
crit = nn.MSELoss()

input = torch.arange(2).float().view(-1, 2)
label = torch.arange(2).float().view(-1, 2)

pred = clf(input)
pred_copy = torch.tensor(pred, requires_grad=True)

opt.zero_grad()
loss_wrong = crit(pred_copy, label)
loss_wrong.backward()
for p in clf.parameters():
    print(p.grad)

opt.zero_grad()
loss_correct = crit(pred, label)
loss_correct.backward()
for p in clf.parameters():
    print(p.grad)

Вывод:

None
None
tensor([[-0.0000, -0.5813],
        [-0.0000, -0.9274]])
tensor([-0.5813, -0.9274])

Из приведенного выше быстрого примера вы можете видеть, что градиенты, вычисленные из скопированного прогноза, не могут быть переданы сетевым параметрам.


Решение

def optimize_model(self, state, q_values):
    output = self.predict_action(state)
    output.requires_grad = False
    qValues.requires_grad = True

Кроме того, как упомянуто @JoshVarty в комментариях, вы не должны преобразовывать какой-либо тензор в вычислительном графе в numpy и преобразовывать его обратно. Это сломало бы график, и, таким образом, градиенты не будут переданы (вычислены) правильно.

tl; dr Только используя как можно больше встроенных функций Pytorch.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...