Я реализую DDPG и застрял, тренируя свои две сети.
Всего у меня есть 4 сети: actor, actor_target, критик и crit_target.
Я обучаю актера и критика в цикле обучения и делаю мягкие обновления для двух других сетей:
def update_weights(self, source, tau):
for target, source in zip(self.parameters(), source.parameters()):
target.data.copy_(tau * source.data + (1 - tau) * target.data)
Мой тренировочный цикл выглядит так:
tensor_next_states = torch.tensor(next_states).view(-1, 1)
prediction_target = self.actor_target(tensor_next_states).data.numpy()
target_critic_output = self.critic_target(
construct_tensor(next_states, prediction_target))
y = torch.tensor(rewards).view(-1,1) + \
self.gamma * target_critic_output
output_critic = self.critic(
torch.tensor(construct_tensor(states, actions), dtype=torch.float))
# compute loss and update critic
self.critic.zero_grad()
loss_critic = self.criterion_critic(y, output_critic)
loss_critic.backward()
self.critic_optim.step()
# compute loss and update actor
tensor_states = torch.tensor(states).view(-1, 1)
ouput_actor = self.actor(tensor_states).data.numpy()
self.actor.zero_grad()
loss_actor = (-1.) * \
self.critic(construct_tensor(states, ouput_actor)).mean()
loss_actor.backward()
self.actor_optim.step()
# update target
self.actor_target.update_weights(self.actor, self.tau)
self.critic_target.update_weights(self.critic, self.tau)
с использованием SGD
в качестве оптимизатора и self.criterion_critic = F.mse_loss
.
construct_tensor(a,b)
создает тензор, подобный [a[0], b[0], a[1], b[1], ...]
.
Я заметил, что RMSE на тестовом наборе до и после тренировки одинаков.
Поэтому я много отлаживал и заметил в update_weights
, что веса обученной сети и целевой сети одинаковы - поэтому я пришел к выводу, что обучение не влияет на веса обученной сети.
Я уже проверил, что вычисленные потери не равны нулю, но все еще с плавающей запятой, проверил замену вызовов zero_grad()
и перемещение вычисленных потерь в self
, что не оказало никакого влияния.
Кто-нибудь уже встречал это поведение и / или есть какие-либо советы или знает, как это исправить?
Обновление: Полный код:
import datetime
import random
from collections import namedtuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
def combine_tensors(s, a):
"""
Combines the two given tensors
:param s: tensor1
:param a: tensor2
:return: combined tensor
"""
target = []
if not len(a[0].shape) == 0:
for i in range(len(s)):
target.append(torch.cat((s[i], a[i])).data.numpy())
else:
for i in range(len(s)):
target.append(torch.cat((s[i], a[i].float().view(-1))) \
.data.numpy())
return torch.tensor(target, device=device)
class actor(nn.Module):
"""
Actor - gets a state (2-dim) and returns probabilities about which
action to take (4 actions -> 4 outputs)
"""
def __init__(self):
super(actor, self).__init__()
# define net structure
self.input_layer = nn.Linear(2, 4)
self.hidden_layer_1 = nn.Linear(4, 8)
self.hidden_layer_2 = nn.Linear(8, 16)
self.hidden_layer_3 = nn.Linear(16, 32)
self.output_layer = nn.Linear(32, 4)
# initialize them
nn.init.xavier_uniform_(self.input_layer.weight)
nn.init.xavier_uniform_(self.hidden_layer_1.weight)
nn.init.xavier_uniform_(self.hidden_layer_2.weight)
nn.init.xavier_uniform_(self.hidden_layer_3.weight)
nn.init.xavier_uniform_(self.output_layer.weight)
nn.init.constant_(self.input_layer.bias, 0.1)
nn.init.constant_(self.hidden_layer_1.bias, 0.1)
nn.init.constant_(self.hidden_layer_2.bias, 0.1)
nn.init.constant_(self.hidden_layer_3.bias, 0.1)
nn.init.constant_(self.output_layer.bias, 0.1)
def forward(self, state):
state = F.relu(self.input_layer(state))
state = F.relu(self.hidden_layer_1(state))
state = F.relu(self.hidden_layer_2(state))
state = F.relu(self.hidden_layer_3(state))
state = F.softmax(self.output_layer(state), dim=0)
return state
class critic(nn.Module):
"""
Critic - gets a state (2-dim) and an action and returns value
"""
def __init__(self):
super(critic, self).__init__()
# define net structure
self.input_layer = nn.Linear(3, 8)
self.hidden_layer_1 = nn.Linear(8, 16)
self.hidden_layer_2 = nn.Linear(16, 32)
self.hidden_layer_3 = nn.Linear(32, 16)
self.output_layer = nn.Linear(16, 1)
# initialize them
nn.init.xavier_uniform_(self.input_layer.weight)
nn.init.xavier_uniform_(self.hidden_layer_1.weight)
nn.init.xavier_uniform_(self.hidden_layer_2.weight)
nn.init.xavier_uniform_(self.hidden_layer_3.weight)
nn.init.xavier_uniform_(self.output_layer.weight)
nn.init.constant_(self.input_layer.bias, 0.1)
nn.init.constant_(self.hidden_layer_1.bias, 0.1)
nn.init.constant_(self.hidden_layer_2.bias, 0.1)
nn.init.constant_(self.hidden_layer_3.bias, 0.1)
nn.init.constant_(self.output_layer.bias, 0.1)
def forward(self, state_, action_):
state_ = combine_tensors(state_, action_)
state_ = F.relu(self.input_layer(state_))
state_ = F.relu(self.hidden_layer_1(state_))
state_ = F.relu(self.hidden_layer_2(state_))
state_ = F.relu(self.hidden_layer_3(state_))
state_ = self.output_layer(state_)
return state_
Transition = namedtuple('Transition',
('state', 'action', 'next_state', 'reward'))
class ReplayMemory(object):
"""
Memory
"""
def __init__(self, capacity):
self.capacity = capacity
self.memory = []
self.position = 0
def push(self, *args):
if len(self.memory) < self.capacity:
self.memory.append(None)
self.memory[self.position] = Transition(*args)
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size):
return random.sample(self.memory, batch_size)
def __len__(self):
return len(self.memory)
def compute_action(actor_trainined, state, eps=0.1):
"""
Computes an action given the actual policy, the state and an eps.
Eps is resposible for the amount of exploring
:param actor_trainined: actual policy
:param state:
:param eps: float in [0,1]
:return:
"""
denoise = random.random()
if denoise > eps:
action_probs = actor_trainined(state.float())
return torch.argmax(action_probs).view(1).int()
else:
return torch.randint(0, 4, (1,)).view(1).int()
def compute_next_state(_action, _state):
"""
Computes the next state given an action and a state
:param _action:
:param _state:
:return:
"""
state_ = _state.clone()
if _action.item() == 0:
state_[1] += 1
elif _action.item() == 1:
state_[1] -= 1
elif _action.item() == 2:
state_[0] -= 1
elif _action.item() == 3:
state_[0] += 1
return state_
def update_weights(target, source, tau):
"""
Soft-Update of weights
:param target:
:param source:
:param tau:
:return:
"""
for target, source in zip(target.parameters(), source.parameters()):
target.data.copy_(tau * source.data + (1 - tau) * target.data)
def update(transition__, replay_memory, batch_size_, gamma_):
"""
Performs one update step
:param transition__:
:param replay_memory:
:param batch_size_:
:param gamma_:
:return:
"""
replay_memory.push(*transition__)
if replay_memory.__len__() < batch_size_:
return
transitions = replay_memory_.sample(batch_size_)
batch = Transition(*zip(*transitions))
states = torch.stack(batch.state)
actions = torch.stack(batch.action)
rewards = torch.stack(batch.reward)
next_states = torch.stack(batch.next_state)
action_target = torch.argmax(actor_target(next_states.float()), 1).int()
y = (
rewards.float().view(-1, 1) +
gamma_ * critic_target(next_states.float(), action_target.float())
.float()
)
critic_trained.zero_grad()
crit_ = critic_trained(states.float(), actions.float())
# nn stuff does not work here! -> doing mse myself..
# loss_critic = (torch.sum((y.float() - crit_.float()) ** 2.)
# / y.data.nelement())
loss_critic = F.l1_loss(y.float(), crit_.float())
loss_critic.backward()
optimizer_critic.step()
actor_trained.zero_grad()
loss_actor = ((-1.) * critic_trained(states.float(),
torch.argmax(
actor_trained(states.float()), 1
).int().float())).mean()
loss_actor.backward()
optimizer_actor.step()
def get_eps(epoch):
"""
Computes the eps for action choosing dependant on the epoch
:param epoch: number of epoch
:return:
"""
if epoch <= 10:
eps_ = 1.
elif epoch <= 20:
eps_ = 0.8
elif epoch <= 40:
eps_ = 0.6
elif epoch <= 60:
eps_ = 0.4
elif epoch <= 80:
eps_ = 0.2
else:
eps_ = 0.1
return eps_
def compute_reward_2(state_, next_state_, terminal_state_):
"""
Better (?) reward function that "compute_reward"
If next_state == terminal_state -> reward = 100
If next_state illegal -> reward = -100
if next_state is further away from terminal_state than state_ -> -2
else 1
:param state_:
:param next_state_:
:param terminal_state_:
:return:
"""
if torch.eq(next_state_, terminal_state_).all():
reward_ = 100
elif torch.eq(next_state_.abs(), 15).any():
reward_ = -100
else:
if (state_.abs() > next_state_.abs()).any():
reward_ = 1.
else:
reward_ = -2
return torch.tensor(reward_, device=device, dtype=torch.float)
def compute_reward(next_state_, terminal_state_):
"""
Computes some reward
:param next_state_:
:param terminal_state_:
:return:
"""
if torch.eq(next_state_, terminal_state_).all():
return torch.tensor(100., device=device, dtype=torch.float)
elif next_state_[0] == 15 or next_state_[1] == 15:
return torch.tensor(-100., device=device, dtype=torch.float)
else:
return (-1.) * next_state_.abs().sum().float()
def fill_memory_2():
"""
Fills the memory with random transitions which got a "good" action chosen
"""
terminal_state_ = torch.tensor([0, 0], device=device, dtype=torch.int)
while replay_memory_.__len__() < batch_size:
state_ = torch.randint(-4, 4, (2,)).to(device).int()
if state_[0].item() == 0 and state_[1].item == 0:
continue
# try to find a "good" action
if state_[0].item() == 0:
if state_[1].item() > 0:
action_ = torch.tensor(1, device=device, dtype=torch.int)
else:
action_ = torch.tensor(0, device=device, dtype=torch.int)
elif state_[1].item() == 0:
if state_[0].item() > 0:
action_ = torch.tensor(2, device=device, dtype=torch.int)
else:
action_ = torch.tensor(3, device=device, dtype=torch.int)
else:
random_bit = random.random()
if random_bit > 0.5:
if state_[1].item() > 0:
action_ = torch.tensor(1, device=device, dtype=torch.int)
else:
action_ = torch.tensor(0, device=device, dtype=torch.int)
else:
if state_[0].item() > 0:
action_ = torch.tensor(2, device=device, dtype=torch.int)
else:
action_ = torch.tensor(3, device=device, dtype=torch.int)
action_ = action_.view(1).int()
next_state_ = compute_next_state(action_, state_)
reward_ = compute_reward_2(state_, next_state_, terminal_state_)
transition__ = Transition(state=state_, action=action_,
reward=reward_, next_state=next_state_)
replay_memory_.push(*transition__)
def fill_memory():
"""
Fills the memory with random transitions
"""
while replay_memory_.__len__() < batch_size:
state_ = torch.randint(-14, 15, (2,)).to(device).int()
if state_[0].item() == 0 and state_[1].item == 0:
continue
terminal_state_ = torch.tensor([0, 0], device=device, dtype=torch.int)
action_ = torch.randint(0, 4, (1,)).view(1).int()
next_state_ = compute_next_state(action_, state_)
reward_ = compute_reward_2(state_, next_state_, terminal_state_)
transition__ = Transition(state=state_, action=action_,
reward=reward_, next_state=next_state_)
replay_memory_.push(*transition__)
if __name__ == '__main__':
# get device if possible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# set seed
seed_ = 0
random.seed(seed_) # seed of python
if device == "cuda":
# cuda seed
torch.cuda.manual_seed(seed_)
else:
# cpu seed
torch.manual_seed(seed_)
# initialize the nets
actor_trained = actor().to(device)
actor_target = actor().to(device)
# copy -> _trained eqaul _target
actor_target.load_state_dict(actor_trained.state_dict())
optimizer_actor = optim.RMSprop(actor_trained.parameters())
# move them to the device
critic_trained = critic().to(device)
critic_target = critic().to(device)
critic_target.load_state_dict((critic_trained.state_dict()))
actor_target.load_state_dict((actor_trained.state_dict()))
# used optimizer
optimizer_critic = optim.RMSprop(critic_trained.parameters(),
momentum=0.9, weight_decay=0.001)
# replay memory
capacity_replay_memory = 16384
replay_memory_ = ReplayMemory(capacity_replay_memory)
# hyperparams
batch_size = 1024
gamma = 0.7
tau = 0.01
num_epochs = 256
# fill replay memory such that batching is possible
fill_memory_2()
# Print params
printing_while_training = True
printing_while_testing = False
print('######################## Training ########################')
starting_time = datetime.datetime.now()
for i in range(num_epochs):
# random state
starting_state = torch.randint(-14, 15, (2,)).to(device).int()
# skip if terminal state
if starting_state[0].item() == 0 and starting_state[0].item() == 0:
continue
state = starting_state.clone()
# terminal state
terminal_state = torch.tensor([0, 0], device=device, dtype=torch.int)
iteration = 0
# get eps for exploring
eps = get_eps(i)
running_reward = 0.
# training loos
while True:
# compute action and next state
action = compute_action(actor_trained, state, eps)
next_state = compute_next_state(action, state)
# finished if next state is terminal state
if torch.eq(next_state, terminal_state).all():
reward = compute_reward_2(state, next_state, terminal_state)
running_reward += reward.item()
transition_ = Transition(state=state, action=action,
reward=reward, next_state=next_state)
replay_memory_.push(*transition_)
if printing_while_training:
print('{}: Finished after {} iterations with reward {} '
'in state {} starting from {}'
.format(i + 1, iteration + 1, running_reward,
next_state.data.numpy(),
starting_state.data.numpy()))
break
# abort if illegal state
elif torch.eq(next_state.abs(), 15).any() or iteration == 99:
reward = compute_reward_2(state, next_state, terminal_state)
running_reward += reward
transition_ = Transition(state=state, action=action,
reward=reward, next_state=next_state)
replay_memory_.push(*transition_)
if printing_while_training:
print('{}: Aborted after {} iterations with reward {} '
'in state {} starting from {}'
.format(i + 1, iteration + 1, running_reward,
next_state.data.numpy(),
starting_state.data.numpy()))
break
# compute immediate reward
reward = compute_reward_2(state, next_state, terminal_state)
# save it - only for logging purposes
running_reward += reward.item()
# construct transition
transition_ = Transition(state=state, action=action, reward=reward,
next_state=next_state)
# update model
update(transition_, replay_memory_, batch_size, gamma)
# perform soft updates
update_weights(actor_target, actor_trained, tau)
update_weights(critic_target, critic_trained, tau)
state = next_state
iteration += 1
print('Ended after: {}'.format(datetime.datetime.now() - starting_time))
print('######################## Testing ########################')
starting_time = datetime.datetime.now()
test_states = [torch.tensor([i, j], device=device, dtype=torch.int)
for i in range(-15, 14) for j in range(-15, 14)]
finished = 0
aborted = 0
aborted_reward = []
finished_reward = []
for starting_state in test_states:
state = starting_state.clone()
terminal_state = torch.tensor([0, 0], device=device, dtype=torch.int)
iteration = 0
reward = 0.
while True:
action = torch.argmax(actor_target(state.float())).view(1).int()
next_state = compute_next_state(action, state)
if torch.eq(next_state, terminal_state).all():
reward += compute_reward_2(state, next_state,
terminal_state)
finished_reward.append(reward.item())
if printing_while_testing:
print('{}: Finished after {} iterations with reward {} '
'in state {} starting from {}'
.format(starting_state.data.numpy(), iteration + 1,
reward.item(), next_state.data.numpy(),
starting_state.data.numpy()))
finished += 1
break
elif torch.eq(next_state.abs(), 15).any():
reward += compute_reward_2(state, next_state,
terminal_state)
aborted_reward.append(reward.item())
if printing_while_testing:
print('{}: Aborted after {} iterations with reward {} '
'in state {} starting from {}'
.format(starting_state.data.numpy(), iteration + 1,
reward.item(), next_state.data.numpy(),
starting_state.data.numpy()))
aborted += 1
break
elif iteration > 500:
if printing_while_testing:
print('Aborting due to more than 500 iterations! '
'Started from {}'.format(
starting_state.data.numpy()))
aborted += 1
break
reward += compute_reward_2(state, next_state, terminal_state)
state = next_state
iteration += 1
print('Ended after: {}'.format(datetime.datetime.now() - starting_time))
print('Finished: {}, aborted: {}'.format(finished, aborted))
print('Reward mean finished: {}, aborted: {}'
.format(np.mean(finished_reward), np.mean(aborted_reward)))
Я уже пытался использовать другую функцию вознаграждения, но она не имела никакого эффекта ....
Кроме того, я попытался использовать менее агрессивное исследование и optim.SGD
вместо optim.RMSprop
- оба не дали эффекта.