RL - Вознаграждение падает после ~ 200 эпизодов - PullRequest
0 голосов
/ 12 апреля 2019

Я построил модель обучения подкреплению с помощью Pytorch. Я использовал принцип Q-Learning.

Модель хорошо работает до 200 эпизодов. Затем награда начинает падать. После этого начинают прыгать на отрицательной стадии. Вы можете увидеть это здесь: reward_2000_epoch

Я провел несколько гиперпараметрических тестов: эпсилон, коэффициент задержки, скорость обучения, гамма. Также в модели, которую я тестировал, с nn.Dropout или нет, также nn.Relu. Но ничто не дает лучшего результата, как на картинке выше. Иногда гораздо худший результат. Код модели:

class Policy(nn.Module):

    def __init__(self):
        super(Policy, self).__init__()
        self.state_space = env.observation_space.shape[0] #2
        self.action_space = env.action_space.n #3
        self.hidden = 50 
        self.l1 = nn.Linear(self.state_space, self.hidden, bias = False)
        self.l2 = nn.Linear(self.hidden, self.action_space, bias = False)

    def forward(self, x):
        model = torch.nn.Sequential(
            self.l1,
            nn.Dropout(p=0.6),
            nn.ReLU(),
            self.l2,
        )
        return model(x)

функция:

loss_fn = nn.MSELoss()
optimizer = optim.SGD(policy.parameters(), lr = learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 1, gamma=0.99)

while not done:
    'chose action'

    #Step forward and recive next state and reward
    Q1 = policy(Variable(torch.from_numpy(state).type(torch.FloatTensor)))
    maxQ1, _ = torch.max(Q1, -1)

    #Create target Q value for training the policy
    Q_target = Q.clone()
    Q_target = Variable(Q_target.data)
    Q_target[action] = reward + torch.mul(maxQ1.detach(), gamma)

    #Calculate loss 
    loss = loss_fn(Q, Q_target)

    #Update policy
    policy.zero_grad()
    loss.backward()
    optimizer.step()

    'find out if done'

    if done:
        scheduler.step()    

Что я могу сделать, чтобы модель училась даже после 200 эпизодов или даже держала награду?

...