import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from random import randint
from matplotlib import pyplot as plt
train = datasets.MNIST("",train=True,download=True,
transform = transforms.Compose([transforms.ToTensor()]))
test = datasets.MNIST("",train=False,download=True,
transform = transforms.Compose([transforms.ToTensor()]))
bs = 64
trainset = torch.utils.data.DataLoader(train, batch_size = bs,
shuffle = True)
testset = torch.utils.data.DataLoader(train, batch_size = 1,
shuffle = True)
def MSE(a,b):
t1 = (a-b)**2
return t1.mean()
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.mp = nn.MaxPool2d(2)
self.fc = nn.Linear(320, 1)
def forward(self, x):
in_size = x.size(0)
x = F.relu(self.mp(self.conv1(x)))
x = F.relu(self.mp(self.conv2(x)))
x = x.view(in_size, -1) # flatten the tensor
x = self.fc(x).relu()
return x
net = Net().cuda()
optimizer = optim.Adam(net.parameters(), lr=(1.0e-3))
print('net created')
losses=[]
for epoch in range(20):
net.train()#training mode
for data in trainset:
t1 = t.time()
x,y = data
x = x.cuda()
y = y.cuda()
optimizer.zero_grad()
output = net(x)
loss = MSE(output, y)
#print(round(float(output),2),float(y),round(float(loss),2))
print(float(loss))
loss.backward()
optimizer.step()
losses.append(float(loss))
CNN возвращает одно значение o / p, которое используется для вычисления ошибки согласно MSE. Следует отметить, что когда я изменяю net на слой 10 нейронов о / р, я получаю гораздо лучшие результаты. В этой ситуации я использую nll_loss в качестве функции потерь, а также применяю softmax к слою o / p. Теоретически я чувствую, что это должно работать. Однако это мой первый треск на CNN. Ошибка остается постоянной на уровне 7-9 для меня. Даже после тренировки на 200 партий я не вижу обучения. Любая информация о том, как это исправить, будет великолепна!