Оценка точности нейронной сети после каждой эпохи - PullRequest
0 голосов
/ 11 ноября 2019
from dataset import get_strange_symbol_loader, get_strange_symbols_test_data
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Net(nn.Module):
   def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(28*28, 512)
    self.fc2 = nn.Linear(512, 256)
    self.fc3 = nn.Linear(256, 15)

def forward(self,x):
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)

    return F.softmax(x, dim=1)


if __name__ == '__main__':
   net = Net()
   train, test = get_strange_symbol_loader(batch_size=128)
   loss_function = nn.CrossEntropyLoss()
   optimizer = optim.Adam(net.parameters(), lr=1e-3)
   Accuracy = []

   for epoch in range(30):   
       print("epoch",epoch)
       #Train
       for data in train:
           img, label = data  
           net.zero_grad()
           output = net(img.view(-1,28*28))
           loss = F.nll_loss(output, label)
           loss.backward()
           optimizer.step()
       #Test    
       correct, total = 0, 0
       with torch.no_grad():
          for data in test:
               img, label = data
               output = net(img.view(-1,784))
               for idx, i in enumerate(output):
                   if torch.argmax(i) == label[idx]:
                       correct += 1
                       total += 1
       Accuracy.append(round(correct/total, 3))
       print("Accuracy: ",Accuracy)

Вот моя нейронная сеть, созданная с помощью PyTorch на основе Sentdex . Я использую набор данных, предоставленный мне администратором моего университетского курса, импортированным функцией get_strange_symbol_loader(batch_size=128).

Когда я запускаю этот код, он говорит мне, что точность в каждой эпохе должна быть 1.0. Однако запуск блока #Test после итерации цикла for, содержащего эпоху, дает несколько более реалистичные результаты. Почему это происходит?

Моя цель состоит в том, чтобы построить график тестирования точности по количеству эпох, чтобы найти оптимальное количество эпох для модели до того, как она начнет перегоняться.

1 Ответ

1 голос
/ 11 ноября 2019

Вы увеличиваете correct и total в блоке

if torch.argmax(i) == label[idx]:
    correct += 1
    total += 1

, поэтому оба имеют всегда одно и то же значение, а одно, разделенное на другое, дает 1,0

Проверьте свои намерения, Я думаю, что удаление вкладки из total +=1 должно сделать это.

РЕДАКТИРОВАТЬ: я предполагаю, что "после запуска блока #test после ..." вы подразумеваете, что вы запускаете другой фрагмент, который может отличаться (возможно, правильно)

...