Недавно я сравнил две модели для DQN в среде CartPole-v0. Один из них представляет собой многослойный персептрон с 3 слоями, а другой - RNN, построенный из LSTM и 1 полностью связанного слоя. У меня есть буфер воспроизведения опыта размером 200000, и обучение не начинается, пока оно не будет заполнено.
Хотя MLP решил эту проблему за разумное количество этапов обучения (это означает достижение среднего вознаграждения в 195 за последние 100 эпизодов), модель RNN не могла сходиться так быстро, а ее максимальное среднее вознаграждение даже не достигло 195!
Я уже пытался увеличить размер партии, добавить больше нейронов в скрытое состояние LSTM, увеличить длину последовательности RNN и сделать полностью связанный слой более сложным - но каждая попытка не удалась, так как я видел огромные колебания среднего вознаграждения, поэтому модель почти не сошлись. Может быть, это поет раннего переоснащения?
class DQN(nn.Module):
def __init__(self, n_input, output_size, n_hidden, n_layers, dropout=0.3):
super(DQN, self).__init__()
self.n_layers = n_layers
self.n_hidden = n_hidden
self.lstm = nn.LSTM(input_size=n_input,
hidden_size=n_hidden,
num_layers=n_layers,
dropout=dropout,
batch_first=True)
self.dropout= nn.Dropout(dropout)
self.fully_connected = nn.Linear(n_hidden, output_size)
def forward(self, x, hidden_parameters):
batch_size = x.size(0)
output, hidden_state = self.lstm(x.float(), hidden_parameters)
seq_length = output.shape[1]
output1 = output.contiguous().view(-1, self.n_hidden)
output2 = self.dropout(output1)
output3 = self.fully_connected(output2)
new = output3.view(batch_size, seq_length, -1)
new = new[:, -1]
return new.float(), hidden_state
def init_hidden(self, batch_size, device):
weight = next(self.parameters()).data
hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_().to(device),
weight.new(self.n_layers, batch_size, self.n_hidden).zero_().to(device))
return hidden
Вопреки тому, что я ожидал, более простая модель дала намного лучший результат, чем другая; хотя предполагается, что RNN лучше обрабатывает данные временных рядов.
Кто-нибудь может мне сказать, в чем причина этого?
Кроме того, я должен заявить, что я не применял разработку функций, и оба DQN работали с необработанными данными. Может ли RNN превзойти MLP при использовании нормализованных функций? (Я имею в виду кормление обеих моделей нормализованными данными)
Есть ли что-нибудь, что вы можете мне порекомендовать для повышения эффективности тренировок на RNN для достижения наилучших результатов?