Remove the F.softmax call: you are effectively computing log_softmax(softmax(x)).

From the CrossEntropyLoss documentation: "This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class."
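Roughly, CrossEntropyLoss(logits, target) is NLLLoss(log_softmax(logits), target), so the network must output raw logits. A quick self-contained check (the tensor shapes and names here are just illustrative):

import torch as t
import torch.nn.functional as F

t.manual_seed(0)
logits = t.randn(4, 100)                 # pretend raw network outputs
target = t.randint(0, 100, (4,))

ce  = t.nn.CrossEntropyLoss()(logits, target)
nll = t.nn.NLLLoss()(F.log_softmax(logits, dim=1), target)
print(t.allclose(ce, nll))               # True: CrossEntropyLoss = LogSoftmax + NLLLoss

# passing softmax(logits) instead silently computes log_softmax(softmax(x))
wrong = t.nn.CrossEntropyLoss()(F.softmax(logits, dim=1), target)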
import torch as t

class Net(t.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.emb = t.nn.Embedding(100, 8)
        self.lstm = t.nn.LSTM(8, 16, batch_first=True)
        self.linear = t.nn.Linear(16, 100)

    def forward(self, x):
        x = self.emb(x)
        x, _ = self.lstm(x)
        x = self.linear(x[:, -1])       # raw logits go straight to CrossEntropyLoss
        # x = t.nn.Softmax(dim=1)(x)    # not needed during training
        return x
t.manual_seed(0)
net = Net()

batch_size = 1
X = t.LongTensor(batch_size, 5).random_(0, 100)   # random input token ids
Y = t.LongTensor(batch_size).random_(0, 100)      # random target class

optimizer = t.optim.Adam(net.parameters())
criterion = t.nn.CrossEntropyLoss()

for epoch in range(10):
    optimizer.zero_grad()
    output = net(X)
    loss = criterion(output, Y)
    loss.backward()
    optimizer.step()
    print(loss.item())
4.401515960693359
4.389760494232178
4.377873420715332
4.365848541259766
4.353675365447998
4.341339588165283
4.328824520111084
4.316114902496338
4.303196430206299
4.2900567054748535
With the t.nn.Softmax line uncommented:
4.602912902832031
4.6027679443359375
4.602619171142578
4.6024675369262695
4.602311611175537
4.602152347564697
4.601987361907959
4.601818084716797
4.6016435623168945
4.601463794708252
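The second run barely moves because Softmax squashes the logits into [0, 1]; after the loss's internal log_softmax the distribution over the 100 classes is almost uniform, so the loss starts near log(100) ≈ 4.605 and the gradients are tiny. A quick illustration with made-up values:

t.manual_seed(0)
logits = t.randn(1, 100)
target = t.tensor([3])

criterion = t.nn.CrossEntropyLoss()
print(criterion(logits, target).item())                        # depends on the logits
print(criterion(t.nn.Softmax(dim=1)(logits), target).item())   # close to log(100) ≈ 4.605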
Use softmax only at evaluation time:
net.eval()
t.nn.Softmax(dim=1)(net(X[0].view(1,-1)))
tensor([[0.0088, 0.0121, 0.0098, 0.0072, 0.0085, 0.0083, 0.0083, 0.0108, 0.0127,
0.0090, 0.0094, 0.0082, 0.0099, 0.0115, 0.0094, 0.0107, 0.0081, 0.0096,
0.0087, 0.0131, 0.0129, 0.0127, 0.0118, 0.0107, 0.0087, 0.0073, 0.0114,
0.0076, 0.0103, 0.0112, 0.0104, 0.0077, 0.0116, 0.0091, 0.0091, 0.0104,
0.0106, 0.0094, 0.0116, 0.0091, 0.0117, 0.0118, 0.0106, 0.0113, 0.0083,
0.0091, 0.0076, 0.0089, 0.0076, 0.0120, 0.0107, 0.0139, 0.0097, 0.0124,
0.0096, 0.0097, 0.0104, 0.0128, 0.0084, 0.0119, 0.0096, 0.0100, 0.0073,
0.0099, 0.0086, 0.0090, 0.0089, 0.0098, 0.0102, 0.0086, 0.0115, 0.0110,
0.0078, 0.0097, 0.0115, 0.0102, 0.0103, 0.0107, 0.0095, 0.0083, 0.0090,
0.0120, 0.0085, 0.0113, 0.0128, 0.0074, 0.0096, 0.0123, 0.0106, 0.0105,
0.0101, 0.0112, 0.0086, 0.0105, 0.0121, 0.0103, 0.0075, 0.0098, 0.0082,
0.0093]], grad_fn=<SoftmaxBackward>)
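If you only need the predicted class rather than the probabilities, an argmax over the raw logits gives the same index (softmax is monotonic), and wrapping inference in no_grad avoids tracking gradients, e.g.:

with t.no_grad():
    logits = net(X[0].view(1, -1))
    probs = t.nn.Softmax(dim=1)(logits)   # probabilities, if you need them
    predicted = probs.argmax(dim=1)       # same index as logits.argmax(dim=1)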