I am implementing a simple RNN that predicts 1/0 for some variable-length time-series data. The network first feeds the training data through an LSTM and then uses a linear layer for classification.
Normally we use mini-batches to train the network. The problem is that this simple RNN does not learn at all when I use batch_size > 1.
I managed to create a minimal code example that reproduces the problem. If you set batch_size=1 where the trainloader is created in main() (line 95 in my script), the network trains successfully, but if you set batch_size=2, the network does not learn at all and the loss just bounces around. (Requires Python 3 and PyTorch >= 0.4.0.)
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence


class ToyDataLoader(object):

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.index = 0
        self.dataset_size = 10
        # generate 10 random variable-length training samples,
        # each time step has 1 feature dimension
        self.X = [
            [[1], [1], [1], [1], [0], [0], [1], [1], [1]],
            [[1], [1], [1], [1]],
            [[0], [0], [1], [1]],
            [[1], [1], [1], [1], [1], [1], [1]],
            [[1], [1]],
            [[0]],
            [[0], [0], [0], [0], [0], [0], [0]],
            [[1]],
            [[0], [1]],
            [[1], [0]]
        ]
        # assign labels for the toy training set
        self.y = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    def __len__(self):
        return self.dataset_size // self.batch_size

    def __iter__(self):
        return self

    def __next__(self):
        if self.index + self.batch_size > self.dataset_size:
            self.index = 0
            raise StopIteration()
        if self.index == 0:  # shuffle the dataset
            tmp = list(zip(self.X, self.y))
            random.shuffle(tmp)
            self.X, self.y = zip(*tmp)
            self.y = torch.LongTensor(self.y)
        X = self.X[self.index: self.index + self.batch_size]
        y = self.y[self.index: self.index + self.batch_size]
        self.index += self.batch_size
        return X, y


class NaiveRNN(nn.Module):

    def __init__(self):
        super(NaiveRNN, self).__init__()
        self.lstm = nn.LSTM(1, 128)
        self.linear = nn.Linear(128, 2)

    def forward(self, X):
        '''
        Parameter:
            X: list containing variable-length training data
        '''
        # get the length of each seq in the batch
        seq_lengths = [len(x) for x in X]
        # convert to torch.Tensor
        seq_tensor = [torch.Tensor(seq) for seq in X]
        # sort seq_tensor by length (descending), as required by
        # torch.nn.utils.rnn.pack_padded_sequence
        pairs = sorted(zip(seq_lengths, seq_tensor),
                       key=lambda pair: pair[0], reverse=True)
        seq_lengths = torch.LongTensor([pair[0] for pair in pairs])
        seq_tensor = [pair[1] for pair in pairs]
        # padded_seq shape: (seq_len, batch_size, feature_size)
        padded_seq = pad_sequence(seq_tensor)
        # pack them up
        packed_seq = pack_padded_sequence(padded_seq, seq_lengths.numpy())
        # feed to rnn
        packed_output, (ht, ct) = self.lstm(packed_seq)
        # linear classification layer on the last hidden state
        y_pred = self.linear(ht[-1])
        return y_pred


def main():
    trainloader = ToyDataLoader(batch_size=2)  # not training at all !!!
    # trainloader = ToyDataLoader(batch_size=1)  # it converges !!!
    model = NaiveRNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adadelta(model.parameters(), lr=1.0)

    for epoch in range(30):
        # switch to train mode
        model.train()
        for i, (X, labels) in enumerate(trainloader):
            # compute output
            outputs = model(X)
            loss = criterion(outputs, labels)
            # measure accuracy and record loss
            _, predicted = torch.max(outputs, 1)
            accu = (predicted == labels).sum().item() / labels.shape[0]
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('Epoch: [{}][{}/{}]\tLoss {:.4f}\tAccu {:.3f}'.format(
                epoch, i, len(trainloader), loss, accu))


if __name__ == '__main__':
    main()
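As a side note on the packing step in forward(): here is a quick standalone check (not part of the repro above) of what pad_sequence and pack_padded_sequence return for a toy two-sequence batch, just to confirm my understanding of the shapes:

import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

# two sequences of lengths 3 and 1, one feature per time step,
# already sorted longest-first as pack_padded_sequence expects
seqs = [torch.Tensor([[1], [1], [0]]), torch.Tensor([[0]])]
padded = pad_sequence(seqs)
print(padded.shape)        # torch.Size([3, 2, 1]) -> (seq_len, batch, feature)
packed = pack_padded_sequence(padded, [3, 1])
print(packed.data.shape)   # torch.Size([4, 1]) -> (sum of lengths, feature)
print(packed.batch_sizes)  # tensor([2, 1, 1]) -> active sequences per time step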
Sample output with batch_size=1:
...
Epoch: [28][7/10] Loss 0.1582 Accu 1.000
Epoch: [28][8/10] Loss 0.2718 Accu 1.000
Epoch: [28][9/10] Loss 0.0000 Accu 1.000
Epoch: [29][0/10] Loss 0.2808 Accu 1.000
Epoch: [29][1/10] Loss 0.0000 Accu 1.000
Epoch: [29][2/10] Loss 0.0001 Accu 1.000
Epoch: [29][3/10] Loss 0.0149 Accu 1.000
Epoch: [29][4/10] Loss 0.1445 Accu 1.000
Epoch: [29][5/10] Loss 0.2866 Accu 1.000
Epoch: [29][6/10] Loss 0.0170 Accu 1.000
Epoch: [29][7/10] Loss 0.0869 Accu 1.000
Epoch: [29][8/10] Loss 0.0000 Accu 1.000
Epoch: [29][9/10] Loss 0.0498 Accu 1.000
Sample output with batch_size=2:
...
Epoch: [27][2/5] Loss 0.8051 Accu 0.000
Epoch: [27][3/5] Loss 1.2835 Accu 0.000
Epoch: [27][4/5] Loss 1.0782 Accu 0.000
Epoch: [28][0/5] Loss 0.5201 Accu 1.000
Epoch: [28][1/5] Loss 0.6587 Accu 0.500
Epoch: [28][2/5] Loss 0.3488 Accu 1.000
Epoch: [28][3/5] Loss 0.5413 Accu 0.500
Epoch: [28][4/5] Loss 0.6769 Accu 0.500
Epoch: [29][0/5] Loss 1.0434 Accu 0.000
Epoch: [29][1/5] Loss 0.4460 Accu 1.000
Epoch: [29][2/5] Loss 0.9879 Accu 0.000
Epoch: [29][3/5] Loss 1.0784 Accu 0.500
Epoch: [29][4/5] Loss 0.6051 Accu 1.000
I have searched through a lot of material and still cannot figure out why.
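One thing I noticed while writing this up (I am not sure it is the cause): forward() re-sorts the batch by length before padding, but the returned predictions are never restored to the original order, while the labels keep the original order. A minimal check of what I mean, using the NaiveRNN defined above and a hypothetical two-sample batch where the sorting swaps the order:

model = NaiveRNN()
X = [[[0]], [[1], [1], [1]]]  # the shorter sequence comes first here
y_pred = model(X)
# inside forward() the batch is re-sorted longest-first, so y_pred[0]
# corresponds to the second input sequence, while a label tensor like
# torch.LongTensor([1, 0]) would still follow the original order
print(y_pred)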