Я предполагаю, что вы хотите рассчитать точность для случая мультикласса (так что ваши классы имеют форму [0, 1, 2, 3, ..., N]
).
Вы используете maximum
, тогда как это должно быть argmax
, например:
def accuracy(predictions, labels):
classes = torch.argmax(predictions, dim=1)
return torch.mean((classes == labels).float())
Этот индекс возврата с максимальным значением, в то время как вы возвращаете максимальную вероятность (или ненормализованную вероятность). Поскольку вероятность почти никогда не равна 1
, вы всегда должны иметь нулевую точность (до float
точности, иногда она может быть действительно близкой к 0
или 1
, поэтому она может "попасть").
Например, 0.9 != 2
, и вы никогда не сможете предсказать класс 2
, но вы можете случайно предсказать класс 1
(0.999999999 ~= 1
).
Эта функция должна вызываться внутри вашего внутреннего l oop, точно так же, как вы рассчитываете потери, поэтому это будет:
for epoch in range(epochs):
running_loss = 0.00
running_accuracy = 0.00
for i, data in enumerate(trainloader, 0):
inputs, labels = data inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = cnn(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
running_accuracy += accuracy(outputs, labels)
running_loss /= len(trainloader)
running_accuracy /= len(trainloader)
То же самое будет go для тестирования и проверки, просто переключите вашу модель в режим оценки (через model.eval()
) и отключить градиент с помощью with torch.no_grad():
диспетчера контекста).