Возобновление обучения модели pytorch вызывает ошибку «CUDA вне памяти» - PullRequest
0 голосов
/ 29 апреля 2020

Моя цель - сохранять модель в каждую эпоху, так как я должен прекратить тренировку ночью, и я не хочу терять прогресс.
После того, как я тренировал свою модель в течение 1 эпохи, я прерывал процесс через терминал с CTRL + Z.
Когда я попытался возобновить тренировку, я получил эту ошибку

THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu line=58 error=2 : out of memory
Traceback (most recent call last):
  File "train.py", line 174, in <module>
    train(train_loader, model, optimizer, epoch)
  File "train.py", line 97, in train
    loss1 = CE(atts, gts)
  File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/modules/loss.py", line 500, in forward
    reduce=self.reduce)
  File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/functional.py", line 1516, in binary_cross_entropy_with_logits
    max_val = (-input).clamp(min=0)
RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu:58

Вот код, который управляет всем

import wandb
import torch
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
import pdb, os, argparse
from datetime import datetime

from model.CPD_models import CPD_VGG
from model.CPD_ResNet_models import CPD_ResNet
from data import get_loader
from utils import clip_gradient, adjust_lr


parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=10, help='epoch number')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--batchsize', type=int, default=1, help='training batch size')
parser.add_argument('--trainsize', type=int, default=352, help='training dataset size')
parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
parser.add_argument('--is_ResNet', type=bool, default=False, help='VGG or ResNet backbone')
parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
parser.add_argument('--decay_epoch', type=int, default=50, help='every n epochs decay learning rate')
parser.add_argument('--model_id', type=str, required=True, help='required unique id for trained model name')
parser.add_argument('--resume', type=str, default='', help='path to resume model training from checkpoint')
parser.add_argument('--wandb', type=bool, default=False, help='enable wandb tracking model training')
opt = parser.parse_args()

model_id = opt.model_id
WANDB_EN = opt.wandb
if WANDB_EN:
    wandb.init(entity="albytree", project="cpd-train")

# Add all parsed config in one line
if WANDB_EN:
    wandb.config.update(opt)
tot_epochs = opt.epoch
print("Training Info")
print("EPOCHS: {}".format(opt.epoch))
print("LEARNING RATE: {}".format(opt.lr))
print("BATCH SIZE: {}".format(opt.batchsize))
print("TRAIN SIZE: {}".format(opt.trainsize))
print("CLIP: {}".format(opt.clip))
print("USING ResNet backbone: {}".format(opt.is_ResNet))
print("DECAY RATE: {}".format(opt.decay_rate))
print("DECAY EPOCH: {}".format(opt.decay_epoch))
print("MODEL ID: {}".format(opt.model_id))

# build models
if opt.is_ResNet:
    model = CPD_ResNet()
else:
    model = CPD_VGG()

model.cuda()
params = model.parameters()
optimizer = torch.optim.Adam(params, opt.lr)
# If no previous training, 0 epochs passed
last_epoch = 0
resume_model_path = opt.resume;
if resume_model_path:
    print("Loading previous trained model:"+resume_model_path)
    checkpoint = torch.load(resume_model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    last_epoch = checkpoint['epoch']
    last_loss = checkpoint['loss']

dataset_name = 'ECSSD'
image_root = '../../DATASETS/TEST/'+dataset_name+'/im/'
gt_root = '../../DATASETS/TEST/'+dataset_name+'/gt/'
train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
total_step = len(train_loader)
print("Total step per epoch: {}".format(total_step))

CE = torch.nn.BCEWithLogitsLoss()

####################################################################################################

def train(train_loader, model, optimizer, epoch):
    model.train()
    for i, pack in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        images, gts = pack
        images = Variable(images)
        gts = Variable(gts)
        images = images.cuda()
        gts = gts.cuda()

        atts, dets = model(images)
        loss1 = CE(atts, gts)
        loss2 = CE(dets, gts)
        loss = loss1 + loss2
        loss.backward()

        clip_gradient(optimizer, opt.clip)
        optimizer.step()
        if WANDB_EN:
            wandb.log({'Loss': loss})
        if i % 100 == 0 or i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} Loss2: {:0.4f}'.
                  format(datetime.now(), epoch, opt.epoch, i, total_step, loss1.data, loss2.data))

    # Save model and optimizer training data
    trained_model_data = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss
    }

    if opt.is_ResNet:
        save_path = 'models/CPD_Resnet/'
    else:
        save_path = 'models/CPD_VGG/'

    if not os.path.exists(save_path):
        print("Making trained model folder [{}]".format(save_path))
        os.makedirs(save_path)

    torch_model_ext = '.pth'
    wandb_model_ext = '.h5'
    model_unique_id = model_id+'_'+'ep'+'_'+'%d' % epoch
    trained_model_name = 'CPD_train' 
    save_full_path_torch = save_path + trained_model_name + '_' + model_unique_id + torch_model_ext 
    save_full_path_wandb = save_path + trained_model_name + '_' + model_unique_id + wandb_model_ext
    if os.path.exists(save_full_path_torch):
        print("Torch model with name ["+save_full_path_torch+"] already exists!")
        answ = raw_input("Do you want to replace it? [y/n] ")
        if("y" in answ):
            torch.save(trained_model_data, save_full_path_torch) 
            print("Saved torch model in "+save_full_path_torch)
    else:
            torch.save(trained_model_data, save_full_path_torch) 
            print("Saved torch model in "+save_full_path_torch)

    if WANDB_EN:
        if os.path.exists(save_full_path_wandb):    
            print("Wandb model with name ["+save_full_path_wandb+"] already exists!")
            answ = raw_input("Do you want to replace it? [y/n] ")
            if("y" in answ):
                wandb.save(save_full_path_wandb)
                print("Saved wandb model in "+save_full_path_wandb)
        else:
                wandb.save(save_full_path_wandb)
                print("Saved wandb model in "+save_full_path_wandb)


####################################################################################################

print("Training on dataset: "+dataset_name)
print("Train images path: "+image_root)
print("Train gt path: "+gt_root)
print("Let's go!")

if WANDB_EN:
    wandb.watch(model, log="all")
for epoch in range(last_epoch+1, tot_epochs+1):
    adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
    train(train_loader, model, optimizer, epoch)
print("TRAINING DONE!")

Кажется, что-то есть неправильно с потерей, но я не могу понять, в чем проблема.

РЕДАКТИРОВАНИЕ 1:

Я обучил модель для 2 эпох без ошибок, а затем прервал процесс.
Я также убил процесс, оставленный в памяти GPU.
После того, как я попытался возобновить модель, сохраненную в эпоха 1 и эпоха 2 , я получил ту же ошибку cuda, но в другой части кода

THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu line=58 error=2 : out of memory
Traceback (most recent call last):
  File "train.py", line 191, in <module>
    train(train_loader, model, optimizer, epoch)
  File "train.py", line 112, in train
    atts, dets = model(images)
  File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/albytree/TESI/CODICE/Workspace/ALGS/CPD/model/CPD_models.py", line 131, in forward
    detection = self.agg2(x5_2, x4_2, x3_2)
  File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/albytree/TESI/CODICE/Workspace/ALGS/CPD/model/CPD_models.py", line 86, in forward
    x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu:58

Более того Я попытался протестировать сохраненную модель в эпоху 1 и эпоху 2 и получил эту ошибку

Traceback (most recent call last):
  File "test.py", line 45, in <module>
    model.load_state_dict(torch.load(opt.model_path))
  File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/modules/module.py", line 721, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for CPD_VGG:
    Missing key(s) in state_dict: "vgg.conv1.conv1_1.bias", "vgg.conv1.conv1_1.weight", "vgg.conv1.conv1_2.bias", "vgg.conv1.conv1_2.weight", "vgg.conv2.conv2_1.bias", "vgg.conv2.conv2_1.weight", "vgg.conv2.conv2_2.bias", "vgg.conv2.conv2_2.weight", "vgg.conv3.conv3_1.bias", "vgg.conv3.conv3_1.weight", "vgg.conv3.conv3_2.bias", "vgg.conv3.conv3_2.weight", "vgg.conv3.conv3_3.bias", "vgg.conv3.conv3_3.weight", "vgg.conv4_1.conv4_1_1.bias", "vgg.conv4_1.conv4_1_1.weight", "vgg.conv4_1.conv4_2_1.bias", "vgg.conv4_1.conv4_2_1.weight", "vgg.conv4_1.conv4_3_1.bias", "vgg.conv4_1.conv4_3_1.weight", "vgg.conv5_1.conv5_1_1.bias", "vgg.conv5_1.conv5_1_1.weight", "vgg.conv5_1.conv5_2_1.bias", "vgg.conv5_1.conv5_2_1.weight", "vgg.conv5_1.conv5_3_1.bias", "vgg.conv5_1.conv5_3_1.weight", "vgg.conv4_2.conv4_1_2.bias", "vgg.conv4_2.conv4_1_2.weight", "vgg.conv4_2.conv4_2_2.bias", "vgg.conv4_2.conv4_2_2.weight", "vgg.conv4_2.conv4_3_2.bias", "vgg.conv4_2.conv4_3_2.weight", "vgg.conv5_2.conv5_1_2.bias", "vgg.conv5_2.conv5_1_2.weight", "vgg.conv5_2.conv5_2_2.bias", "vgg.conv5_2.conv5_2_2.weight", "vgg.conv5_2.conv5_3_2.bias", "vgg.conv5_2.conv5_3_2.weight", "rfb3_1.branch0.0.bias", "rfb3_1.branch0.0.weight", "rfb3_1.branch1.0.bias", "rfb3_1.branch1.0.weight", "rfb3_1.branch1.1.bias", "rfb3_1.branch1.1.weight", "rfb3_1.branch1.2.bias", "rfb3_1.branch1.2.weight", "rfb3_1.branch1.3.bias", "rfb3_1.branch1.3.weight", "rfb3_1.branch2.0.bias", "rfb3_1.branch2.0.weight", "rfb3_1.branch2.1.bias", "rfb3_1.branch2.1.weight", "rfb3_1.branch2.2.bias", "rfb3_1.branch2.2.weight", "rfb3_1.branch2.3.bias", "rfb3_1.branch2.3.weight", "rfb3_1.branch3.0.bias", "rfb3_1.branch3.0.weight", "rfb3_1.branch3.1.bias", "rfb3_1.branch3.1.weight", "rfb3_1.branch3.2.bias", "rfb3_1.branch3.2.weight", "rfb3_1.branch3.3.bias", "rfb3_1.branch3.3.weight", "rfb3_1.conv_cat.bias", "rfb3_1.conv_cat.weight", "rfb3_1.conv_res.bias", "rfb3_1.conv_res.weight", "rfb4_1.branch0.0.bias", "rfb4_1.branch0.0.weight", "rfb4_1.branch1.0.bias", "rfb4_1.branch1.0.weight", "rfb4_1.branch1.1.bias", "rfb4_1.branch1.1.weight", "rfb4_1.branch1.2.bias", "rfb4_1.branch1.2.weight", "rfb4_1.branch1.3.bias", "rfb4_1.branch1.3.weight", "rfb4_1.branch2.0.bias", "rfb4_1.branch2.0.weight", "rfb4_1.branch2.1.bias", "rfb4_1.branch2.1.weight", "rfb4_1.branch2.2.bias", "rfb4_1.branch2.2.weight", "rfb4_1.branch2.3.bias", "rfb4_1.branch2.3.weight", "rfb4_1.branch3.0.bias", "rfb4_1.branch3.0.weight", "rfb4_1.branch3.1.bias", "rfb4_1.branch3.1.weight", "rfb4_1.branch3.2.bias", "rfb4_1.branch3.2.weight", "rfb4_1.branch3.3.bias", "rfb4_1.branch3.3.weight", "rfb4_1.conv_cat.bias", "rfb4_1.conv_cat.weight", "rfb4_1.conv_res.bias", "rfb4_1.conv_res.weight", "rfb5_1.branch0.0.bias", "rfb5_1.branch0.0.weight", "rfb5_1.branch1.0.bias", "rfb5_1.branch1.0.weight", "rfb5_1.branch1.1.bias", "rfb5_1.branch1.1.weight", "rfb5_1.branch1.2.bias", "rfb5_1.branch1.2.weight", "rfb5_1.branch1.3.bias", "rfb5_1.branch1.3.weight", "rfb5_1.branch2.0.bias", "rfb5_1.branch2.0.weight", "rfb5_1.branch2.1.bias", "rfb5_1.branch2.1.weight", "rfb5_1.branch2.2.bias", "rfb5_1.branch2.2.weight", "rfb5_1.branch2.3.bias", "rfb5_1.branch2.3.weight", "rfb5_1.branch3.0.bias", "rfb5_1.branch3.0.weight", "rfb5_1.branch3.1.bias", "rfb5_1.branch3.1.weight", "rfb5_1.branch3.2.bias", "rfb5_1.branch3.2.weight", "rfb5_1.branch3.3.bias", "rfb5_1.branch3.3.weight", "rfb5_1.conv_cat.bias", "rfb5_1.conv_cat.weight", "rfb5_1.conv_res.bias", "rfb5_1.conv_res.weight", "agg1.conv_upsample1.bias", "agg1.conv_upsample1.weight", "agg1.conv_upsample2.bias", "agg1.conv_upsample2.weight", "agg1.conv_upsample3.bias", "agg1.conv_upsample3.weight", "agg1.conv_upsample4.bias", "agg1.conv_upsample4.weight", "agg1.conv_upsample5.bias", "agg1.conv_upsample5.weight", "agg1.conv_concat2.bias", "agg1.conv_concat2.weight", "agg1.conv_concat3.bias", "agg1.conv_concat3.weight", "agg1.conv4.bias", "agg1.conv4.weight", "agg1.conv5.bias", "agg1.conv5.weight", "rfb3_2.branch0.0.bias", "rfb3_2.branch0.0.weight", "rfb3_2.branch1.0.bias", "rfb3_2.branch1.0.weight", "rfb3_2.branch1.1.bias", "rfb3_2.branch1.1.weight", "rfb3_2.branch1.2.bias", "rfb3_2.branch1.2.weight", "rfb3_2.branch1.3.bias", "rfb3_2.branch1.3.weight", "rfb3_2.branch2.0.bias", "rfb3_2.branch2.0.weight", "rfb3_2.branch2.1.bias", "rfb3_2.branch2.1.weight", "rfb3_2.branch2.2.bias", "rfb3_2.branch2.2.weight", "rfb3_2.branch2.3.bias", "rfb3_2.branch2.3.weight", "rfb3_2.branch3.0.bias", "rfb3_2.branch3.0.weight", "rfb3_2.branch3.1.bias", "rfb3_2.branch3.1.weight", "rfb3_2.branch3.2.bias", "rfb3_2.branch3.2.weight", "rfb3_2.branch3.3.bias", "rfb3_2.branch3.3.weight", "rfb3_2.conv_cat.bias", "rfb3_2.conv_cat.weight", "rfb3_2.conv_res.bias", "rfb3_2.conv_res.weight", "rfb4_2.branch0.0.bias", "rfb4_2.branch0.0.weight", "rfb4_2.branch1.0.bias", "rfb4_2.branch1.0.weight", "rfb4_2.branch1.1.bias", "rfb4_2.branch1.1.weight", "rfb4_2.branch1.2.bias", "rfb4_2.branch1.2.weight", "rfb4_2.branch1.3.bias", "rfb4_2.branch1.3.weight", "rfb4_2.branch2.0.bias", "rfb4_2.branch2.0.weight", "rfb4_2.branch2.1.bias", "rfb4_2.branch2.1.weight", "rfb4_2.branch2.2.bias", "rfb4_2.branch2.2.weight", "rfb4_2.branch2.3.bias", "rfb4_2.branch2.3.weight", "rfb4_2.branch3.0.bias", "rfb4_2.branch3.0.weight", "rfb4_2.branch3.1.bias", "rfb4_2.branch3.1.weight", "rfb4_2.branch3.2.bias", "rfb4_2.branch3.2.weight", "rfb4_2.branch3.3.bias", "rfb4_2.branch3.3.weight", "rfb4_2.conv_cat.bias", "rfb4_2.conv_cat.weight", "rfb4_2.conv_res.bias", "rfb4_2.conv_res.weight", "rfb5_2.branch0.0.bias", "rfb5_2.branch0.0.weight", "rfb5_2.branch1.0.bias", "rfb5_2.branch1.0.weight", "rfb5_2.branch1.1.bias", "rfb5_2.branch1.1.weight", "rfb5_2.branch1.2.bias", "rfb5_2.branch1.2.weight", "rfb5_2.branch1.3.bias", "rfb5_2.branch1.3.weight", "rfb5_2.branch2.0.bias", "rfb5_2.branch2.0.weight", "rfb5_2.branch2.1.bias", "rfb5_2.branch2.1.weight", "rfb5_2.branch2.2.bias", "rfb5_2.branch2.2.weight", "rfb5_2.branch2.3.bias", "rfb5_2.branch2.3.weight", "rfb5_2.branch3.0.bias", "rfb5_2.branch3.0.weight", "rfb5_2.branch3.1.bias", "rfb5_2.branch3.1.weight", "rfb5_2.branch3.2.bias", "rfb5_2.branch3.2.weight", "rfb5_2.branch3.3.bias", "rfb5_2.branch3.3.weight", "rfb5_2.conv_cat.bias", "rfb5_2.conv_cat.weight", "rfb5_2.conv_res.bias", "rfb5_2.conv_res.weight", "agg2.conv_upsample1.bias", "agg2.conv_upsample1.weight", "agg2.conv_upsample2.bias", "agg2.conv_upsample2.weight", "agg2.conv_upsample3.bias", "agg2.conv_upsample3.weight", "agg2.conv_upsample4.bias", "agg2.conv_upsample4.weight", "agg2.conv_upsample5.bias", "agg2.conv_upsample5.weight", "agg2.conv_concat2.bias", "agg2.conv_concat2.weight", "agg2.conv_concat3.bias", "agg2.conv_concat3.weight", "agg2.conv4.bias", "agg2.conv4.weight", "agg2.conv5.bias", "agg2.conv5.weight", "HA.gaussian_kernel". 
    Unexpected key(s) in state_dict: "loss", "optimizer_state_dict", "model_state_dict", "epoch".

Может быть, я не сохраняю состояния, как предполагалось?
Странно то, что перед добавлением учебного кода возобновления я просто сохранял модель в каждую эпоху только с torch.save(model.state_dict(), save_full_path_torch): мне удалось обучить модель в 10 эпохах, и она все еще работает во время тестирования.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...