Модель PyTorch не учится - PullRequest
0 голосов
/ 14 июня 2019

Я пытаюсь написать код для задачи классификации собак и кошек , используя предварительно обученную модель resnet18 в PyTorch.Вот мой код для обучения:

from torchvision import models
model = models.resnet18(pretrained=True)
for param in model.parameters():
        param.requires_grad = True
model.classifier = nn.Sequential(nn.Linear(1024, 256),
                                     nn.ReLU(),
                                     nn.Dropout(0.2),
                                     nn.Linear(256, 2),
                                    )
train_loader = DataLoader(DogvsCatDataloader(dirname='data/train/train/', img_height=128, img_width=128, num_channels=3,
                                         num_classes=2, train_phase=True, transform=transformer), shuffle=True, batch_size=batch_size)

model.cuda()
model.train()
for epoch in range(epochs):
    tot_loss = 0
    correct = 0
    for data, target in T(train_loader):
        target = target.view(target.size(0))
        optimizer.zero_grad()
        pred = model(data)
        loss = criterion(pred, target)
        tot_loss += loss.data.cpu().numpy()
        loss.backward()
        optimizer.step()
        pred = pred.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum().numpy()
    print("Epoch: ", epoch, "Loss: ", tot_loss / len(train_loader.dataset), 'Acc: ',
          100.0*correct/len(train_loader.dataset))

Проблема в том, что окончательная потеря всегда остается около 0,7, а точность остается около 50% для этой задачи двоичной классификации, то есть она ничего не изучает.

Вот мой dataloader класс:

import torch
from torch.utils.data import Dataset
from torch.autograd import Variable
import os
from PIL import Image
import numpy as np
from torch.backends import cudnn


class  DogvsCatDataloader(Dataset):
    def __init__(self, dirname, img_width, img_height, num_channels, num_classes, train_phase=True, transform=None):
        super(Dataset, self).__init__()
        self.dirname = dirname
        self.images = os.listdir(self.dirname)
        self.len = len(self.images)
        self.img_width = img_width
        self.img_height = img_height
        self.num_channels = num_channels
        self.num_classes = num_classes
        self.train_phase = train_phase
        self.label = None
        self.transform = transform

    def __len__(self):
        return self.len

    def __getitem__(self, item):
        img = Image.open(os.path.join(self.dirname, self.images[item]))
        img = img.resize((self.img_width, self.img_height))
        # img = np.array(img).reshape(self.num_channels, self.img_width, self.img_height)
        if self.transform:
            img = self.transform(img)
        if self.train_phase:
            if 'cat' in self.images[item]:
                self.label = 0
            else:
                self.label = 1
            return Variable(torch.Tensor(img).cuda()), Variable(torch.LongTensor([self.label]).cuda())
        else:
            return Variable(torch.Tensor(img).cuda())

Что я делаю не так?

...