Я пытаюсь использовать PyTorch для распечатки точности предсказания каждого класса на основе официального учебного пособия ссылка
Но, похоже, что-то идет не так.Мой код намеревается выполнить эту работу следующим образом:
for epoch in range(num_epochs):
# Each epoch has a training and validation phase
for phase in ['train', 'val']:
... (this is given by the tutorial)
(my code)
class_correct = list(0. for i in range(3))
class_total = list(0. for i in range(3))
for data in dataloaders['val']:
images, labels = data
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
c = (predicted == labels.data).squeeze()
for i in range(4):
label = labels.data[i]
class_correct[label] += c[i]
class_total[label] += 1
for i in range(3):
print('Accuracy of {} : {} / {} = {:.4f} %'.format(i,
class_correct[i], class_total[i], 100 * class_correct[i].item() /
class_total[i]))
print(file = f)
print()
Например, выходные данные эпохи 1/1:
Я думаю, чтодолжно быть выполнено следующее уравнение:
running_corrects: = 2 + 2
Но, как мне кажется, ничего не происходит.
Что там не так?Надеюсь, что кто-то может указать на мою ошибку и научить меня, как это сделать правильно.
Спасибо!