Сеть не меняет веса во время тренировки, pytorch - PullRequest
0 голосов
/ 29 июня 2018

Я реализую DDPG и застрял, тренируя свои две сети.

Всего у меня есть 4 сети: actor, actor_target, критик и crit_target. Я обучаю актера и критика в цикле обучения и делаю мягкие обновления для двух других сетей:

def update_weights(self, source, tau):
        for target, source in zip(self.parameters(), source.parameters()):
            target.data.copy_(tau * source.data + (1 - tau) * target.data)

Мой тренировочный цикл выглядит так:

tensor_next_states = torch.tensor(next_states).view(-1, 1)
prediction_target = self.actor_target(tensor_next_states).data.numpy()
target_critic_output = self.critic_target(
    construct_tensor(next_states, prediction_target))
y = torch.tensor(rewards).view(-1,1) + \
    self.gamma * target_critic_output
output_critic = self.critic(
    torch.tensor(construct_tensor(states, actions), dtype=torch.float))

# compute loss and update critic
self.critic.zero_grad()
loss_critic = self.criterion_critic(y, output_critic)
loss_critic.backward()
self.critic_optim.step()

# compute loss and update actor
tensor_states = torch.tensor(states).view(-1, 1)
ouput_actor = self.actor(tensor_states).data.numpy()
self.actor.zero_grad()
loss_actor = (-1.) * \
             self.critic(construct_tensor(states, ouput_actor)).mean()
loss_actor.backward()
self.actor_optim.step()

# update target
self.actor_target.update_weights(self.actor, self.tau)
self.critic_target.update_weights(self.critic, self.tau)

с использованием SGD в качестве оптимизатора и self.criterion_critic = F.mse_loss.

construct_tensor(a,b) создает тензор, подобный [a[0], b[0], a[1], b[1], ...].

Я заметил, что RMSE на тестовом наборе до и после тренировки одинаков. Поэтому я много отлаживал и заметил в update_weights, что веса обученной сети и целевой сети одинаковы - поэтому я пришел к выводу, что обучение не влияет на веса обученной сети. Я уже проверил, что вычисленные потери не равны нулю, но все еще с плавающей запятой, проверил замену вызовов zero_grad() и перемещение вычисленных потерь в self, что не оказало никакого влияния.

Кто-нибудь уже встречал это поведение и / или есть какие-либо советы или знает, как это исправить?

Обновление: Полный код:

import datetime
import random
from collections import namedtuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


def combine_tensors(s, a):
    """
    Combines the two given tensors
    :param s: tensor1
    :param a: tensor2
    :return: combined tensor
    """
    target = []
    if not len(a[0].shape) == 0:
        for i in range(len(s)):
            target.append(torch.cat((s[i], a[i])).data.numpy())
    else:
        for i in range(len(s)):
            target.append(torch.cat((s[i], a[i].float().view(-1))) \
                          .data.numpy())
    return torch.tensor(target, device=device)


class actor(nn.Module):
    """
    Actor - gets a state (2-dim) and returns probabilities about which
    action to take (4 actions -> 4 outputs)
    """

    def __init__(self):
        super(actor, self).__init__()

        # define net structure
        self.input_layer = nn.Linear(2, 4)
        self.hidden_layer_1 = nn.Linear(4, 8)
        self.hidden_layer_2 = nn.Linear(8, 16)
        self.hidden_layer_3 = nn.Linear(16, 32)
        self.output_layer = nn.Linear(32, 4)

        # initialize them
        nn.init.xavier_uniform_(self.input_layer.weight)
        nn.init.xavier_uniform_(self.hidden_layer_1.weight)
        nn.init.xavier_uniform_(self.hidden_layer_2.weight)
        nn.init.xavier_uniform_(self.hidden_layer_3.weight)
        nn.init.xavier_uniform_(self.output_layer.weight)

        nn.init.constant_(self.input_layer.bias, 0.1)
        nn.init.constant_(self.hidden_layer_1.bias, 0.1)
        nn.init.constant_(self.hidden_layer_2.bias, 0.1)
        nn.init.constant_(self.hidden_layer_3.bias, 0.1)
        nn.init.constant_(self.output_layer.bias, 0.1)

    def forward(self, state):
        state = F.relu(self.input_layer(state))
        state = F.relu(self.hidden_layer_1(state))
        state = F.relu(self.hidden_layer_2(state))
        state = F.relu(self.hidden_layer_3(state))
        state = F.softmax(self.output_layer(state), dim=0)
        return state


class critic(nn.Module):
    """
    Critic - gets a state (2-dim) and an action and returns value
    """

    def __init__(self):
        super(critic, self).__init__()
        # define net structure
        self.input_layer = nn.Linear(3, 8)
        self.hidden_layer_1 = nn.Linear(8, 16)
        self.hidden_layer_2 = nn.Linear(16, 32)
        self.hidden_layer_3 = nn.Linear(32, 16)
        self.output_layer = nn.Linear(16, 1)

        # initialize them
        nn.init.xavier_uniform_(self.input_layer.weight)
        nn.init.xavier_uniform_(self.hidden_layer_1.weight)
        nn.init.xavier_uniform_(self.hidden_layer_2.weight)
        nn.init.xavier_uniform_(self.hidden_layer_3.weight)
        nn.init.xavier_uniform_(self.output_layer.weight)

        nn.init.constant_(self.input_layer.bias, 0.1)
        nn.init.constant_(self.hidden_layer_1.bias, 0.1)
        nn.init.constant_(self.hidden_layer_2.bias, 0.1)
        nn.init.constant_(self.hidden_layer_3.bias, 0.1)
        nn.init.constant_(self.output_layer.bias, 0.1)

    def forward(self, state_, action_):
        state_ = combine_tensors(state_, action_)
        state_ = F.relu(self.input_layer(state_))
        state_ = F.relu(self.hidden_layer_1(state_))
        state_ = F.relu(self.hidden_layer_2(state_))
        state_ = F.relu(self.hidden_layer_3(state_))
        state_ = self.output_layer(state_)
        return state_


Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """
    Memory
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


def compute_action(actor_trainined, state, eps=0.1):
    """
    Computes an action given the actual policy, the state and an eps.
    Eps is resposible for the amount of exploring
    :param actor_trainined: actual policy
    :param state:
    :param eps: float in [0,1]
    :return:
    """
    denoise = random.random()

    if denoise > eps:
        action_probs = actor_trainined(state.float())
        return torch.argmax(action_probs).view(1).int()
    else:
        return torch.randint(0, 4, (1,)).view(1).int()


def compute_next_state(_action, _state):
    """
    Computes the next state given an action and a state
    :param _action:
    :param _state:
    :return:
    """
    state_ = _state.clone()
    if _action.item() == 0:
        state_[1] += 1
    elif _action.item() == 1:
        state_[1] -= 1
    elif _action.item() == 2:
        state_[0] -= 1
    elif _action.item() == 3:
        state_[0] += 1

    return state_


def update_weights(target, source, tau):
    """
    Soft-Update of weights
    :param target:
    :param source:
    :param tau:
    :return:
    """
    for target, source in zip(target.parameters(), source.parameters()):
        target.data.copy_(tau * source.data + (1 - tau) * target.data)


def update(transition__, replay_memory, batch_size_, gamma_):
    """
    Performs one update step
    :param transition__:
    :param replay_memory:
    :param batch_size_:
    :param gamma_:
    :return:
    """
    replay_memory.push(*transition__)

    if replay_memory.__len__() < batch_size_:
        return

    transitions = replay_memory_.sample(batch_size_)
    batch = Transition(*zip(*transitions))

    states = torch.stack(batch.state)
    actions = torch.stack(batch.action)
    rewards = torch.stack(batch.reward)
    next_states = torch.stack(batch.next_state)

    action_target = torch.argmax(actor_target(next_states.float()), 1).int()
    y = (
            rewards.float().view(-1, 1) +
            gamma_ * critic_target(next_states.float(), action_target.float())
            .float()
    )

    critic_trained.zero_grad()
    crit_ = critic_trained(states.float(), actions.float())
    # nn stuff does not work here! -> doing mse myself..
    # loss_critic = (torch.sum((y.float() - crit_.float()) ** 2.)
    #                / y.data.nelement())
    loss_critic = F.l1_loss(y.float(), crit_.float())
    loss_critic.backward()
    optimizer_critic.step()

    actor_trained.zero_grad()
    loss_actor = ((-1.) * critic_trained(states.float(),
                                         torch.argmax(
                                             actor_trained(states.float()), 1
                                         ).int().float())).mean()
    loss_actor.backward()
    optimizer_actor.step()


def get_eps(epoch):
    """
    Computes the eps for action choosing dependant on the epoch
    :param epoch: number of epoch
    :return:
    """
    if epoch <= 10:
        eps_ = 1.
    elif epoch <= 20:
        eps_ = 0.8
    elif epoch <= 40:
        eps_ = 0.6
    elif epoch <= 60:
        eps_ = 0.4
    elif epoch <= 80:
        eps_ = 0.2
    else:
        eps_ = 0.1
    return eps_


def compute_reward_2(state_, next_state_, terminal_state_):
    """
    Better (?) reward function that "compute_reward"
    If next_state == terminal_state -> reward = 100
    If next_state illegal           -> reward = -100
    if next_state is further away from terminal_state than state_ -> -2
    else 1
    :param state_:
    :param next_state_:
    :param terminal_state_:
    :return:
    """
    if torch.eq(next_state_, terminal_state_).all():
        reward_ = 100
    elif torch.eq(next_state_.abs(), 15).any():
        reward_ = -100
    else:
        if (state_.abs() > next_state_.abs()).any():
            reward_ = 1.
        else:
            reward_ = -2
    return torch.tensor(reward_, device=device, dtype=torch.float)


def compute_reward(next_state_, terminal_state_):
    """
    Computes some reward
    :param next_state_:
    :param terminal_state_:
    :return:
    """
    if torch.eq(next_state_, terminal_state_).all():
        return torch.tensor(100., device=device, dtype=torch.float)
    elif next_state_[0] == 15 or next_state_[1] == 15:
        return torch.tensor(-100., device=device, dtype=torch.float)
    else:
        return (-1.) * next_state_.abs().sum().float()


def fill_memory_2():
    """
    Fills the memory with random transitions which got a "good" action chosen
    """
    terminal_state_ = torch.tensor([0, 0], device=device, dtype=torch.int)
    while replay_memory_.__len__() < batch_size:
        state_ = torch.randint(-4, 4, (2,)).to(device).int()
        if state_[0].item() == 0 and state_[1].item == 0:
            continue

        # try to find a "good" action
        if state_[0].item() == 0:
            if state_[1].item() > 0:
                action_ = torch.tensor(1, device=device, dtype=torch.int)
            else:
                action_ = torch.tensor(0, device=device, dtype=torch.int)
        elif state_[1].item() == 0:
            if state_[0].item() > 0:
                action_ = torch.tensor(2, device=device, dtype=torch.int)
            else:
                action_ = torch.tensor(3, device=device, dtype=torch.int)
        else:
            random_bit = random.random()
            if random_bit > 0.5:
                if state_[1].item() > 0:
                    action_ = torch.tensor(1, device=device, dtype=torch.int)
                else:
                    action_ = torch.tensor(0, device=device, dtype=torch.int)
            else:
                if state_[0].item() > 0:
                    action_ = torch.tensor(2, device=device, dtype=torch.int)
                else:
                    action_ = torch.tensor(3, device=device, dtype=torch.int)

        action_ = action_.view(1).int()
        next_state_ = compute_next_state(action_, state_)
        reward_ = compute_reward_2(state_, next_state_, terminal_state_)

        transition__ = Transition(state=state_, action=action_,
                                  reward=reward_, next_state=next_state_)
        replay_memory_.push(*transition__)


def fill_memory():
    """
    Fills the memory with random transitions
    """
    while replay_memory_.__len__() < batch_size:
        state_ = torch.randint(-14, 15, (2,)).to(device).int()
        if state_[0].item() == 0 and state_[1].item == 0:
            continue
        terminal_state_ = torch.tensor([0, 0], device=device, dtype=torch.int)
        action_ = torch.randint(0, 4, (1,)).view(1).int()
        next_state_ = compute_next_state(action_, state_)
        reward_ = compute_reward_2(state_, next_state_, terminal_state_)

        transition__ = Transition(state=state_, action=action_,
                                  reward=reward_, next_state=next_state_)
        replay_memory_.push(*transition__)


if __name__ == '__main__':
    # get device if possible
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # set seed
    seed_ = 0
    random.seed(seed_)  # seed of python
    if device == "cuda":
        # cuda seed
        torch.cuda.manual_seed(seed_)
    else:
        # cpu seed
        torch.manual_seed(seed_)

    # initialize the nets
    actor_trained = actor().to(device)
    actor_target = actor().to(device)
    # copy -> _trained eqaul _target
    actor_target.load_state_dict(actor_trained.state_dict())
    optimizer_actor = optim.RMSprop(actor_trained.parameters())
    # move them to the device
    critic_trained = critic().to(device)
    critic_target = critic().to(device)
    critic_target.load_state_dict((critic_trained.state_dict()))
    actor_target.load_state_dict((actor_trained.state_dict()))
    # used optimizer
    optimizer_critic = optim.RMSprop(critic_trained.parameters(),
                                     momentum=0.9, weight_decay=0.001)

    # replay memory
    capacity_replay_memory = 16384
    replay_memory_ = ReplayMemory(capacity_replay_memory)

    # hyperparams
    batch_size = 1024
    gamma = 0.7
    tau = 0.01
    num_epochs = 256

    # fill replay memory such that batching is possible
    fill_memory_2()

    # Print params
    printing_while_training = True
    printing_while_testing = False

    print('######################## Training ########################')
    starting_time = datetime.datetime.now()
    for i in range(num_epochs):
        # random state
        starting_state = torch.randint(-14, 15, (2,)).to(device).int()
        # skip if terminal state
        if starting_state[0].item() == 0 and starting_state[0].item() == 0:
            continue
        state = starting_state.clone()
        # terminal state
        terminal_state = torch.tensor([0, 0], device=device, dtype=torch.int)
        iteration = 0

        # get eps for exploring
        eps = get_eps(i)

        running_reward = 0.

        # training loos
        while True:
            # compute action and next state
            action = compute_action(actor_trained, state, eps)
            next_state = compute_next_state(action, state)

            # finished if next state is terminal state
            if torch.eq(next_state, terminal_state).all():
                reward = compute_reward_2(state, next_state, terminal_state)
                running_reward += reward.item()
                transition_ = Transition(state=state, action=action,
                                         reward=reward, next_state=next_state)
                replay_memory_.push(*transition_)
                if printing_while_training:
                    print('{}: Finished after {} iterations with reward {} '
                          'in state {} starting from {}'
                          .format(i + 1, iteration + 1, running_reward,
                                  next_state.data.numpy(),
                                  starting_state.data.numpy()))
                break
            # abort if illegal state
            elif torch.eq(next_state.abs(), 15).any() or iteration == 99:
                reward = compute_reward_2(state, next_state, terminal_state)
                running_reward += reward
                transition_ = Transition(state=state, action=action,
                                         reward=reward, next_state=next_state)
                replay_memory_.push(*transition_)
                if printing_while_training:
                    print('{}: Aborted after {} iterations with reward {} '
                          'in state {} starting from {}'
                          .format(i + 1, iteration + 1, running_reward,
                                  next_state.data.numpy(),
                                  starting_state.data.numpy()))
                break

            # compute immediate reward
            reward = compute_reward_2(state, next_state, terminal_state)
            # save it - only for logging purposes
            running_reward += reward.item()

            # construct transition
            transition_ = Transition(state=state, action=action, reward=reward,
                                     next_state=next_state)

            # update model
            update(transition_, replay_memory_, batch_size, gamma)
            # perform soft updates
            update_weights(actor_target, actor_trained, tau)
            update_weights(critic_target, critic_trained, tau)

            state = next_state
            iteration += 1
    print('Ended after: {}'.format(datetime.datetime.now() - starting_time))

    print('######################## Testing ########################')
    starting_time = datetime.datetime.now()
    test_states = [torch.tensor([i, j], device=device, dtype=torch.int)
                   for i in range(-15, 14) for j in range(-15, 14)]
    finished = 0
    aborted = 0
    aborted_reward = []
    finished_reward = []

    for starting_state in test_states:
        state = starting_state.clone()
        terminal_state = torch.tensor([0, 0], device=device, dtype=torch.int)
        iteration = 0
        reward = 0.

        while True:
            action = torch.argmax(actor_target(state.float())).view(1).int()
            next_state = compute_next_state(action, state)

            if torch.eq(next_state, terminal_state).all():
                reward += compute_reward_2(state, next_state,
                                           terminal_state)
                finished_reward.append(reward.item())
                if printing_while_testing:
                    print('{}: Finished after {} iterations with reward {} '
                          'in state {} starting from {}'
                          .format(starting_state.data.numpy(), iteration + 1,
                                  reward.item(), next_state.data.numpy(),
                                  starting_state.data.numpy()))
                finished += 1
                break
            elif torch.eq(next_state.abs(), 15).any():
                reward += compute_reward_2(state, next_state,
                                           terminal_state)
                aborted_reward.append(reward.item())
                if printing_while_testing:
                    print('{}: Aborted after {} iterations with reward {} '
                          'in state {} starting from {}'
                          .format(starting_state.data.numpy(), iteration + 1,
                                  reward.item(), next_state.data.numpy(),
                                  starting_state.data.numpy()))
                aborted += 1
                break
            elif iteration > 500:
                if printing_while_testing:
                    print('Aborting due to more than 500 iterations! '
                          'Started from {}'.format(
                        starting_state.data.numpy()))
                aborted += 1
                break
            reward += compute_reward_2(state, next_state, terminal_state)
            state = next_state
            iteration += 1

    print('Ended after: {}'.format(datetime.datetime.now() - starting_time))
    print('Finished: {}, aborted: {}'.format(finished, aborted))
    print('Reward mean finished: {}, aborted: {}'
          .format(np.mean(finished_reward), np.mean(aborted_reward)))

Я уже пытался использовать другую функцию вознаграждения, но она не имела никакого эффекта ....

Кроме того, я попытался использовать менее агрессивное исследование и optim.SGD вместо optim.RMSprop - оба не дали эффекта.

1 Ответ

0 голосов
/ 16 октября 2018

Теперь это может быть прямой ответ или рецепт для работы вашего кода, но у меня есть некоторые начальные опасения, которые могут помочь вам отладить код.

Самой большой проблемой, я полагаю, является то, что вы выполняете несколько преобразований в типы данных, которые не являются тензорами. Например, вы вызываете функции Combine_tensors () пару раз, и он преобразует данные тензоры в numpy () и создает новый тензор при возврате значения. В другой раз вы вызываете свои сети для выполнения прямого прохода и передаете им тензоры, преобразованные с помощью функции float () в качестве аргумента. Также есть вызовы int () на тензорах. Все эти вызовы приводят к потере графика тензора, который используется для вычисления градиента при вызове backward () . Это описано в документации PyTorch и должно быть понято до написания алгоритмов RL в этой среде. Крайне важно работать с тензорами все время, пока вы находитесь в функции поезда - с момента преобразования пакета опыта в тензоры и до вызова обратных функций.

Одно это еще не гарантирует, что обучение будет выполнено правильно. Например, когда вы используете целевые сети для оценки потерь для критика, вы должны отсоединить результаты, чтобы предотвратить вычисления градиента в целевых сетях (хотя, если вы используете оптимизатор и регистрируете только параметры критика, это скорее проблема с производительностью: step () вызов не будет обновлять параметры целевой сети).

Когда обе проблемы будут решены в вашем коде, вы можете наблюдать более правильное поведение. Мой дополнительный комментарий заключается в том, что я не совсем понимаю части вашего кода и считаю, что это неправильная реализация DDPG (то есть вы используете argmax () для выходных данных сети актера и предоставляете это критичной сети и это не похоже на правильный путь).

Я бы посоветовал вам сделать шаг назад и немного лучше понять структуру и идеи PyTorch, а также поискать некоторые базовые реализации DDPG, чтобы убедиться, что вы знаете, как выполнять вычисления шаг за шагом.

...