Как использовать PyTorch, чтобы распечатать точность прогноза каждого класса? - PullRequest
0 голосов
/ 15 мая 2018

Я пытаюсь использовать PyTorch для распечатки точности предсказания каждого класса на основе официального учебного пособия ссылка

Но, похоже, что-то идет не так.Мой код намеревается выполнить эту работу следующим образом:

    for epoch in range(num_epochs):

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
       ... (this is given by the tutorial)

    (my code)

    class_correct = list(0. for i in range(3))
    class_total = list(0. for i in range(3))

    for data in dataloaders['val']:
        images, labels = data
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        c = (predicted == labels.data).squeeze()

        for i in range(4):
            label = labels.data[i]
            class_correct[label] += c[i]
            class_total[label] += 1

    for i in range(3):
        print('Accuracy of {} : {} / {} = {:.4f} %'.format(i, 
class_correct[i], class_total[i], 100 * class_correct[i].item() / 
class_total[i]))

    print(file = f)
    print()

Например, выходные данные эпохи 1/1: enter image description here

Я думаю, чтодолжно быть выполнено следующее уравнение:

running_corrects: = 2 + 2

Но, как мне кажется, ничего не происходит.

Что там не так?Надеюсь, что кто-то может указать на мою ошибку и научить меня, как это сделать правильно.

Спасибо!

1 Ответ

0 голосов
/ 16 мая 2018

Наконец-то я решил эту проблему. Сначала я сравнил параметры двух моделей и выяснил, что они одинаковы. Поэтому я подтвердил, что модель такая же. А потом я проверил два входа и с удивлением обнаружил, что они разные.

Таким образом, я внимательно изучил входные данные двух моделей, и ответ состоял в том, что аргументы, переданные второй модели, обновили , а не .

Код:

for data in dataloaders['val']:
    images, labels = data
    outputs = model(inputs)

Изменить на:

for data in dataloaders['val']:
    inputs, labels = data
    outputs = model(inputs)

Готово! * * 1013

...