нейронная сеть pytorch (вероятно) не учится - PullRequest
0 голосов
/ 30 декабря 2018

Я пытаюсь написать агента DDPG, чтобы играть в футбольную игру в pytorch.Первоначально с агентом все в порядке (при наличии шума), но в процессе обучения (а шум уменьшается) сеть актеров просто выводит нули, делая игрока неподвижным.

Я проверил вывод сети с образцомвход и вроде нормально работает (не дает ноль).Есть ли какие-либо ошибки в Pytorch, которые могут вызвать это или это из-за ошибки в коде?Актерская сеть ::

class Actor(nn.Module):
def __init__(self, nb_states, nb_actions, hidden1=20, hidden2=30, init_w=5):
    super(Actor, self).__init__()
    self.fc1 = nn.Linear(nb_states, hidden1)
    self.fc2 = nn.Linear(hidden1, hidden2)
    self.fc3 = nn.Linear(hidden2, nb_actions)
    self.relu = nn.ReLU()
    self.tanh = nn.Tanh()
    self.init_weights(init_w)

def init_weights(self, init_w):
    self.fc1.weight.data = fanin_init(self.fc1.weight.data.size())
    self.fc2.weight.data = fanin_init(self.fc2.weight.data.size())
    self.fc3.weight.data.uniform_(-init_w, init_w)

def forward(self, x):
    out = self.fc1(x)
    out = self.relu(out)
    out = self.fc2(out)
    out = self.relu(out)
    out = self.fc3(out)
    out = self.tanh(out)
    return out

Обучение ::

def critic_train(self, s1, a1, r1, s2):
    a2 = self.trgt_actor.forward(s2).detach()
    next_val = torch.squeeze(self.trgt_critic.forward((s2, a2)).detach())

    y_expected = r1 + self.GAMMA * next_val

    y_predicted = torch.squeeze(self.critic.forward((s1, a1)))

    loss_critic = F.smooth_l1_loss(y_predicted, y_expected)

    self.critic_optim.zero_grad()

    loss_critic.backward()

    self.critic_optim.step()

    return None

def actor_train(self, s1, a1, r1, s2):
    pred_a1 = self.actor.forward(s1)

    loss_actor = -1 * torch.sum(self.critic.forward((s1, pred_a1)))

    self.actor_optim.zero_grad()

    loss_actor.backward()

    self.actor_optim.step()
    soft_update(self.trgt_actor, self.actor, 0.01)
    soft_update(self.trgt_critic, self.critic, 0.01)

    return None

Спасибо

...