Точность резко меняется при переходе от CPU к GPU - PullRequest
1 голос
/ 13 июля 2020

Итак, я работал над данными классификации меланомы SIIM на Kaggle. Обучая несколько сетей, я обнаружил, что когда я тренировал их на CPU, точность казалась подходящей, около 0,75. При переключении на графический процессор точность будет колебаться в пределах 0,5 или около того. Что мне с этим делать? Вот фрагмент кода тренировки l oop. В итоге обученная модель была resnext50.

import cv2
device = "cpu"
import torch.nn.functional as F
epochs=3
#model = torch.load("model.pt")
model.cpu()
#model.cuda()
print("======== Training for ", epochs, "epochs=============")
for epoch in range(epochs):
    total_loss = 0
    model.train()
    print("Training.......")
    print("======== EPOCH #",epoch,"=================")
    tmp_acc = 0
    for i,batch in enumerate(train_loader):
        
        img,label = batch["images"],batch["labels"]
        #img = img.permute(0,3,1,2)
        #img = torch.Tensor(img)
        label = label.type(torch.FloatTensor)
        
        img,label = img.to(device),label.to(device)
        
        model.zero_grad()
        
        op = model(img)
        
        label_cpu = label.cpu().numpy()
        
        op = F.sigmoid(op)
        
        
        output = op.detach().cpu().numpy()
        tmp_acc += accuracy_score(output,label_cpu)
        loss = criterion(op,label)
        total_loss = loss.item()
        loss.backward()
        
        adam.step()
        if(i%10==0 and i>0):
            print("STEP: ",i, "of steps ",len(train_loader))
        
            print("Current loss: ",total_loss/i)
            print("Training Accuracy ",tmp_acc/i)
            
        
        
    avg_loss = total_loss/len(train_loader)
    print("The loss after ",epoch," epochs is ",avg_loss)
    print("OP",op)
    print("Label",label_cpu)
        
    

            
        
            
torch.save(model.state_dict(),"/kaggle/working/model.pt")    
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...