Я пытаюсь потренироваться с LSTM и Pytorch.Я взял набор данных обзора фильмов IMDB , чтобы предсказать, будет ли обзор положительным или отрицательным.Я использую 80% набора данных для своих тренировок, убираю знаки препинания, использую GloVe
(с 200 оттенками) в качестве слоя для встраивания.
Перед тренировкой я также исключаю слишком короткие (обзоры длиной менее 50 символов) и слишком длинные (обзоры длиной более 1000 символов).
Для слоя LSTM
, который я используюhidden dimension 256
, num_layers 2
и one directional
параметры с 0.5 dropout
.После этого у меня есть полностью связанный слой.Для обучения я использовал функцию nn.BCELoss с оптимизатором Adam (lr=0.001
).
В настоящее время я получаю 85% точности проверки с 98% точностью обучения после 7 эпох.Я сделал следующие шаги для предотвращения переоснащения и получения более высокой точности:
- использовал weight_decay для оптимизатора Адама,
- пробовал SGD (lr = 0,1, 0,001) вместо Адама,
- пытался увеличить num_layers LSTM,
Во всех этих случаях модель не училась вообще, давая 50% точности для обучающих и проверочных наборов.
class CustomLSTM(nn.Module):
def __init__(self, vocab_size, use_embed=False, embed=None, embedding_size=200, hidden_size=256,
num_lstm_layers=2, bidirectional=False, dropout=0.5, output_dims=2):
super().__init__()
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.hidden_size = hidden_size
self.num_lstm_layers = num_lstm_layers
self.bidirectional = bidirectional
self.dropout = dropout
self.embedding = nn.Embedding(vocab_size, embedding_size)
if use_embed:
self.embedding.weight.data.copy_(torch.from_numpy(embed))
# self.embedding.requires_grad = False
self.lstm = nn.LSTM(input_size=embedding_size,
hidden_size=hidden_size,
num_layers=num_lstm_layers,
batch_first=True,
dropout=dropout,
bidirectional=bidirectional)
# print('output dims value ', output_dims)
self.drop_fc = nn.Dropout(0.5)
self.fc = nn.Linear(hidden_size, output_dims)
self.sig = nn.Sigmoid()
Я хочу понять:
- Почему модель не учится вообще с этими изменениями?
- Как повысить точность?