I am trying to solve the CartPole problem with PyTorch, but after a few iterations the parameters are not updating.
The code I am trying to reproduce is [cartpole](https://github.com/gsurma/cartpole/blob/master/cartpole.py), which is written in Keras.
import random
from collections import deque
import gym
import numpy as np
import torch.nn as nn
import torch

GAMMA = 0.95
MEMORY_SIZE = 1000000
BATCH_SIZE = 50
EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.01
EXPLORATION_DECAY = 0.995


# Principal Neural Network module
class DQNSolver(nn.Module):
    def __init__(self, observation_space, action_space):
        super(DQNSolver, self).__init__()
        self.action_space = action_space
        self.observation_space = observation_space
        self.hiddenSpace = 24
        self.fc1 = nn.Linear(self.observation_space, self.hiddenSpace)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(self.hiddenSpace, self.action_space)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out


class cartpole():
    def __init__(self):
        # Device configuration
        device = torch.device('cpu' if torch.cuda.is_available() else 'cpu')
        self.env = gym.make("CartPole-v1")
        self.observation_space = self.env.observation_space.shape[0]
        self.action_space = self.env.action_space.n
        self.dqn_solver = DQNSolver(self.observation_space, self.action_space).to(device)
        # Create the memory
        self.memory = deque(maxlen=MEMORY_SIZE)
        self.exploration_rate = EXPLORATION_MAX
        # Create the optimizer and the loss
        self.optimizer = torch.optim.Adam(self.dqn_solver.parameters(), lr=0.1)
        self.loss_func = torch.nn.MSELoss()

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    # This custom function will receive an observation of the state of the
    # environment, and then return the action
    def predict_action(self, observation):
        # if np.random.rand() < self.exploration_rate:
        #     randomQ = np.random.rand(1, self.action_space)[0]
        #     return np.double(randomQ)
        predicted = self.dqn_solver(observation)
        predicted = predicted.cpu().data.numpy()
        return np.double(predicted)

    def optimize_model(self, state, q_values):
        output = self.predict_action(state)
        output = torch.tensor(output, requires_grad=False)
        qValues = torch.tensor(q_values, requires_grad=True)
        self.optimizer.zero_grad()
        loss = self.loss_func(output, qValues)
        loss.backward()
        self.optimizer.step()
        print("Loss: {}".format(loss))

    def experience_replay(self):
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        for state, action, reward, state_next, terminal in batch:
            q_update = reward
            if not terminal:
                next_action = self.predict_action(state_next)
                q_update = (reward + GAMMA * np.amax(next_action))
            q_values = self.predict_action(state)
            q_values[action] = q_update
            self.optimize_model(state, q_values)
        print("Finished replay")
        print('weights after backpropagation = ', list(self.dqn_solver.parameters()))
        self.exploration_rate *= EXPLORATION_DECAY
        self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate)

    def run(self):
        while True:
            state = self.env.reset()
            # state = np.reshape(state, [1, self.observation_space])
            state = torch.Tensor(state)
            while True:
                self.env.render()
                action = self.predict_action(state)
                action = np.argmax(action)
                state_next, reward, terminal, info = self.env.step(action)
                reward = reward if not terminal else -reward
                state_next = torch.Tensor(state_next)
                self.remember(state, action, reward, state_next, terminal)
                state = state_next
                if terminal:
                    break
            self.experience_replay()


if __name__ == "__main__":
    cartpole().run()
After several epochs, the parameters printed at this line are exactly the same:
print('weights after backpropagation = ', list(self.dqn_solver.parameters()))
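To make sure it is not just the printout, this is the kind of check I could wrap around the call to experience_replay() (a debugging sketch of mine; weights_changed is a made-up helper, not part of the original code):

# Hypothetical debugging helper: snapshot the parameters before a call,
# run it, and report whether any parameter tensor actually changed.
def weights_changed(model, fn):
    before = [p.detach().clone() for p in model.parameters()]
    fn()
    after = list(model.parameters())
    return any(not torch.equal(b, a) for b, a in zip(before, after))

# usage inside run(), replacing the plain call:
# print("parameters changed:", weights_changed(self.dqn_solver, self.experience_replay))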
Also, the loss values look essentially random; they neither increase nor decrease over time. What could be wrong?
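For reference, this is the diagnostic I could add inside optimize_model right after loss.backward() (just a sketch; the loop below is mine, not part of the original code) to see whether any gradient actually reaches the network weights:

# Hypothetical diagnostic: print the gradient of every network parameter
# right after loss.backward(); a parameter whose grad stays None never
# receives a gradient from the loss.
for name, p in self.dqn_solver.named_parameters():
    print(name, "grad is None" if p.grad is None else p.grad.abs().sum().item())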