Моя цель — сохранять модель после каждой эпохи: на ночь мне приходится останавливать обучение, и я не хочу терять прогресс.
После того как модель обучилась в течение одной эпохи, я прервал процесс в терминале с помощью CTRL + Z.
Когда я попытался возобновить обучение, я получил следующую ошибку:
THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu line=58 error=2 : out of memory
Traceback (most recent call last):
File "train.py", line 174, in <module>
train(train_loader, model, optimizer, epoch)
File "train.py", line 97, in train
loss1 = CE(atts, gts)
File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/modules/loss.py", line 500, in forward
reduce=self.reduce)
File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/functional.py", line 1516, in binary_cross_entropy_with_logits
max_val = (-input).clamp(min=0)
RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu:58
Вот код, который всем этим управляет:
import wandb
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import pdb, os, argparse
from datetime import datetime
from model.CPD_models import CPD_VGG
from model.CPD_ResNet_models import CPD_ResNet
from data import get_loader
from utils import clip_gradient, adjust_lr
def _str2bool(value):
    """Convert a command-line string such as ``true``/``false`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value (including ``"False"``) becomes ``True``. This
    converter interprets the text instead; an empty string stays ``False``.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


# Command-line configuration of the training run.
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=10, help='epoch number')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--batchsize', type=int, default=1, help='training batch size')
parser.add_argument('--trainsize', type=int, default=352, help='training dataset size')
parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
parser.add_argument('--is_ResNet', type=_str2bool, default=False, help='VGG or ResNet backbone')
parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
parser.add_argument('--decay_epoch', type=int, default=50, help='every n epochs decay learning rate')
parser.add_argument('--model_id', type=str, required=True, help='required unique id for trained model name')
parser.add_argument('--resume', type=str, default='', help='path to resume model training from checkpoint')
parser.add_argument('--wandb', type=_str2bool, default=False, help='enable wandb tracking model training')
opt = parser.parse_args()

model_id = opt.model_id
WANDB_EN = opt.wandb
if WANDB_EN:
    wandb.init(entity="albytree", project="cpd-train")
    # Add all parsed config in one line
    wandb.config.update(opt)
tot_epochs = opt.epoch

# Echo the effective configuration so it ends up in the run log.
print("Training Info")
print("EPOCHS: {}".format(opt.epoch))
print("LEARNING RATE: {}".format(opt.lr))
print("BATCH SIZE: {}".format(opt.batchsize))
print("TRAIN SIZE: {}".format(opt.trainsize))
print("CLIP: {}".format(opt.clip))
print("USING ResNet backbone: {}".format(opt.is_ResNet))
print("DECAY RATE: {}".format(opt.decay_rate))
print("DECAY EPOCH: {}".format(opt.decay_epoch))
print("MODEL ID: {}".format(opt.model_id))
# build models
if opt.is_ResNet:
    model = CPD_ResNet()
else:
    model = CPD_VGG()
model.cuda()
params = model.parameters()
optimizer = torch.optim.Adam(params, opt.lr)

# If no previous training, 0 epochs passed
last_epoch = 0
resume_model_path = opt.resume
if resume_model_path:
    print("Loading previous trained model:"+resume_model_path)
    # Deserialize onto the CPU: without map_location the saved CUDA tensors
    # are restored on the GPU, so a second copy of the weights and of the
    # optimizer state would sit in GPU memory next to the live ones.
    # load_state_dict() copies the values to the devices of the live
    # parameters.
    checkpoint = torch.load(resume_model_path,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    last_epoch = checkpoint['epoch']
    last_loss = checkpoint['loss']
    # Drop the reference so the loaded copy can be garbage-collected before
    # training starts.
    del checkpoint

# Training data location and loader.
dataset_name = 'ECSSD'
image_root = '../../DATASETS/TEST/'+dataset_name+'/im/'
gt_root = '../../DATASETS/TEST/'+dataset_name+'/gt/'
train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
total_step = len(train_loader)
print("Total step per epoch: {}".format(total_step))

# The same loss instance is applied to both model outputs in train().
CE = torch.nn.BCEWithLogitsLoss()
####################################################################################################
def _confirm_overwrite(path, label):
    """Return True if *path* may be written.

    A path that does not exist yet is always writable. For an existing path
    the user is asked interactively; any answer containing ``y`` allows the
    overwrite.
    """
    if not os.path.exists(path):
        return True
    print(label+" model with name ["+path+"] already exists!")
    try:
        ask = raw_input  # Python 2
    except NameError:
        ask = input  # Python 3
    answ = ask("Do you want to replace it? [y/n] ")
    return "y" in answ


def _save_checkpoint(model, optimizer, epoch, loss_value):
    """Write the resumable checkpoint for *epoch* and optionally notify wandb.

    The checkpoint is a dict with the keys ``model_state_dict``,
    ``optimizer_state_dict``, ``epoch`` and ``loss``. To use it for inference,
    load ``checkpoint['model_state_dict']`` into the model rather than the
    whole dict.
    """
    trained_model_data = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        # Plain Python number (or None): storing the loss tensor itself would
        # put a CUDA tensor into the checkpoint.
        'loss': loss_value
    }
    if opt.is_ResNet:
        save_path = 'models/CPD_Resnet/'
    else:
        save_path = 'models/CPD_VGG/'
    if not os.path.exists(save_path):
        print("Making trained model folder [{}]".format(save_path))
        os.makedirs(save_path)

    torch_model_ext = '.pth'
    wandb_model_ext = '.h5'
    model_unique_id = model_id+'_'+'ep'+'_'+'%d' % epoch
    trained_model_name = 'CPD_train'
    base_name = save_path + trained_model_name + '_' + model_unique_id
    save_full_path_torch = base_name + torch_model_ext
    save_full_path_wandb = base_name + wandb_model_ext

    if _confirm_overwrite(save_full_path_torch, "Torch"):
        torch.save(trained_model_data, save_full_path_torch)
        print("Saved torch model in "+save_full_path_torch)

    if WANDB_EN:
        # NOTE(review): nothing in this script writes the .h5 file, so
        # wandb.save() is pointed at a path that presumably does not exist;
        # confirm whether the .pth checkpoint should be uploaded instead.
        if _confirm_overwrite(save_full_path_wandb, "Wandb"):
            wandb.save(save_full_path_wandb)
            print("Saved wandb model in "+save_full_path_wandb)


def train(train_loader, model, optimizer, epoch):
    """Run one training epoch on the GPU and save a checkpoint afterwards.

    Args:
        train_loader: iterable yielding ``(images, gts)`` batches; must
            support ``len()``.
        model: network returning two outputs ``(atts, dets)`` for a batch.
        optimizer: optimizer updating the parameters of ``model``.
        epoch: 1-based number of the epoch being trained; used in the log
            line and in the checkpoint file name.

    The total loss is the sum of the module-level ``CE`` criterion applied to
    both model outputs against the same ground truth.
    """
    model.train()
    steps_per_epoch = len(train_loader)
    # Stays None if the loader yields no batch, so saving still works.
    loss_value = None
    for i, pack in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        images, gts = pack
        images = images.cuda()
        gts = gts.cuda()

        atts, dets = model(images)
        loss1 = CE(atts, gts)
        loss2 = CE(dets, gts)
        loss = loss1 + loss2
        loss.backward()

        clip_gradient(optimizer, opt.clip)
        optimizer.step()

        # Convert to a Python float right away: logging or storing the
        # tensor would keep a reference to GPU memory alive.
        loss_value = loss.item()
        if WANDB_EN:
            wandb.log({'Loss': loss_value})

        if i % 100 == 0 or i == steps_per_epoch:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} Loss2: {:0.4f}'.
                  format(datetime.now(), epoch, opt.epoch, i, steps_per_epoch,
                         loss1.item(), loss2.item()))

    # Save model and optimizer training data
    _save_checkpoint(model, optimizer, epoch, loss_value)
####################################################################################################
# Announce where the training data comes from before the first epoch.
for banner_line in (
    "Training on dataset: "+dataset_name,
    "Train images path: "+image_root,
    "Train gt path: "+gt_root,
    "Let's go!",
):
    print(banner_line)

if WANDB_EN:
    wandb.watch(model, log="all")

# Continue with the epoch after the last completed one (epoch 1 on a fresh run).
first_epoch = last_epoch+1
for epoch in range(first_epoch, tot_epochs+1):
    adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
    train(train_loader, model, optimizer, epoch)

print("TRAINING DONE!")
Похоже, что-то не так с функцией потерь, но я не могу понять, в чём проблема.
РЕДАКТИРОВАНИЕ 1:
Я обучил модель в течение 2 эпох без ошибок, а затем прервал процесс.
Я также завершил процесс, остававшийся в памяти GPU.
Когда я попытался возобновить обучение с моделей, сохранённых после эпохи 1 и эпохи 2, я получил ту же ошибку CUDA, но в другой части кода:
THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu line=58 error=2 : out of memory
Traceback (most recent call last):
File "train.py", line 191, in <module>
train(train_loader, model, optimizer, epoch)
File "train.py", line 112, in train
atts, dets = model(images)
File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/home/albytree/TESI/CODICE/Workspace/ALGS/CPD/model/CPD_models.py", line 131, in forward
detection = self.agg2(x5_2, x4_2, x3_2)
File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/home/albytree/TESI/CODICE/Workspace/ALGS/CPD/model/CPD_models.py", line 86, in forward
x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu:58
Более того, я попытался протестировать модели, сохранённые после эпохи 1 и эпохи 2, и получил следующую ошибку:
Traceback (most recent call last):
File "test.py", line 45, in <module>
model.load_state_dict(torch.load(opt.model_path))
File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/modules/module.py", line 721, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for CPD_VGG:
Missing key(s) in state_dict: "vgg.conv1.conv1_1.bias", "vgg.conv1.conv1_1.weight", "vgg.conv1.conv1_2.bias", "vgg.conv1.conv1_2.weight", "vgg.conv2.conv2_1.bias", "vgg.conv2.conv2_1.weight", "vgg.conv2.conv2_2.bias", "vgg.conv2.conv2_2.weight", "vgg.conv3.conv3_1.bias", "vgg.conv3.conv3_1.weight", "vgg.conv3.conv3_2.bias", "vgg.conv3.conv3_2.weight", "vgg.conv3.conv3_3.bias", "vgg.conv3.conv3_3.weight", "vgg.conv4_1.conv4_1_1.bias", "vgg.conv4_1.conv4_1_1.weight", "vgg.conv4_1.conv4_2_1.bias", "vgg.conv4_1.conv4_2_1.weight", "vgg.conv4_1.conv4_3_1.bias", "vgg.conv4_1.conv4_3_1.weight", "vgg.conv5_1.conv5_1_1.bias", "vgg.conv5_1.conv5_1_1.weight", "vgg.conv5_1.conv5_2_1.bias", "vgg.conv5_1.conv5_2_1.weight", "vgg.conv5_1.conv5_3_1.bias", "vgg.conv5_1.conv5_3_1.weight", "vgg.conv4_2.conv4_1_2.bias", "vgg.conv4_2.conv4_1_2.weight", "vgg.conv4_2.conv4_2_2.bias", "vgg.conv4_2.conv4_2_2.weight", "vgg.conv4_2.conv4_3_2.bias", "vgg.conv4_2.conv4_3_2.weight", "vgg.conv5_2.conv5_1_2.bias", "vgg.conv5_2.conv5_1_2.weight", "vgg.conv5_2.conv5_2_2.bias", "vgg.conv5_2.conv5_2_2.weight", "vgg.conv5_2.conv5_3_2.bias", "vgg.conv5_2.conv5_3_2.weight", "rfb3_1.branch0.0.bias", "rfb3_1.branch0.0.weight", "rfb3_1.branch1.0.bias", "rfb3_1.branch1.0.weight", "rfb3_1.branch1.1.bias", "rfb3_1.branch1.1.weight", "rfb3_1.branch1.2.bias", "rfb3_1.branch1.2.weight", "rfb3_1.branch1.3.bias", "rfb3_1.branch1.3.weight", "rfb3_1.branch2.0.bias", "rfb3_1.branch2.0.weight", "rfb3_1.branch2.1.bias", "rfb3_1.branch2.1.weight", "rfb3_1.branch2.2.bias", "rfb3_1.branch2.2.weight", "rfb3_1.branch2.3.bias", "rfb3_1.branch2.3.weight", "rfb3_1.branch3.0.bias", "rfb3_1.branch3.0.weight", "rfb3_1.branch3.1.bias", "rfb3_1.branch3.1.weight", "rfb3_1.branch3.2.bias", "rfb3_1.branch3.2.weight", "rfb3_1.branch3.3.bias", "rfb3_1.branch3.3.weight", "rfb3_1.conv_cat.bias", "rfb3_1.conv_cat.weight", "rfb3_1.conv_res.bias", "rfb3_1.conv_res.weight", "rfb4_1.branch0.0.bias", "rfb4_1.branch0.0.weight", 
"rfb4_1.branch1.0.bias", "rfb4_1.branch1.0.weight", "rfb4_1.branch1.1.bias", "rfb4_1.branch1.1.weight", "rfb4_1.branch1.2.bias", "rfb4_1.branch1.2.weight", "rfb4_1.branch1.3.bias", "rfb4_1.branch1.3.weight", "rfb4_1.branch2.0.bias", "rfb4_1.branch2.0.weight", "rfb4_1.branch2.1.bias", "rfb4_1.branch2.1.weight", "rfb4_1.branch2.2.bias", "rfb4_1.branch2.2.weight", "rfb4_1.branch2.3.bias", "rfb4_1.branch2.3.weight", "rfb4_1.branch3.0.bias", "rfb4_1.branch3.0.weight", "rfb4_1.branch3.1.bias", "rfb4_1.branch3.1.weight", "rfb4_1.branch3.2.bias", "rfb4_1.branch3.2.weight", "rfb4_1.branch3.3.bias", "rfb4_1.branch3.3.weight", "rfb4_1.conv_cat.bias", "rfb4_1.conv_cat.weight", "rfb4_1.conv_res.bias", "rfb4_1.conv_res.weight", "rfb5_1.branch0.0.bias", "rfb5_1.branch0.0.weight", "rfb5_1.branch1.0.bias", "rfb5_1.branch1.0.weight", "rfb5_1.branch1.1.bias", "rfb5_1.branch1.1.weight", "rfb5_1.branch1.2.bias", "rfb5_1.branch1.2.weight", "rfb5_1.branch1.3.bias", "rfb5_1.branch1.3.weight", "rfb5_1.branch2.0.bias", "rfb5_1.branch2.0.weight", "rfb5_1.branch2.1.bias", "rfb5_1.branch2.1.weight", "rfb5_1.branch2.2.bias", "rfb5_1.branch2.2.weight", "rfb5_1.branch2.3.bias", "rfb5_1.branch2.3.weight", "rfb5_1.branch3.0.bias", "rfb5_1.branch3.0.weight", "rfb5_1.branch3.1.bias", "rfb5_1.branch3.1.weight", "rfb5_1.branch3.2.bias", "rfb5_1.branch3.2.weight", "rfb5_1.branch3.3.bias", "rfb5_1.branch3.3.weight", "rfb5_1.conv_cat.bias", "rfb5_1.conv_cat.weight", "rfb5_1.conv_res.bias", "rfb5_1.conv_res.weight", "agg1.conv_upsample1.bias", "agg1.conv_upsample1.weight", "agg1.conv_upsample2.bias", "agg1.conv_upsample2.weight", "agg1.conv_upsample3.bias", "agg1.conv_upsample3.weight", "agg1.conv_upsample4.bias", "agg1.conv_upsample4.weight", "agg1.conv_upsample5.bias", "agg1.conv_upsample5.weight", "agg1.conv_concat2.bias", "agg1.conv_concat2.weight", "agg1.conv_concat3.bias", "agg1.conv_concat3.weight", "agg1.conv4.bias", "agg1.conv4.weight", "agg1.conv5.bias", "agg1.conv5.weight", 
"rfb3_2.branch0.0.bias", "rfb3_2.branch0.0.weight", "rfb3_2.branch1.0.bias", "rfb3_2.branch1.0.weight", "rfb3_2.branch1.1.bias", "rfb3_2.branch1.1.weight", "rfb3_2.branch1.2.bias", "rfb3_2.branch1.2.weight", "rfb3_2.branch1.3.bias", "rfb3_2.branch1.3.weight", "rfb3_2.branch2.0.bias", "rfb3_2.branch2.0.weight", "rfb3_2.branch2.1.bias", "rfb3_2.branch2.1.weight", "rfb3_2.branch2.2.bias", "rfb3_2.branch2.2.weight", "rfb3_2.branch2.3.bias", "rfb3_2.branch2.3.weight", "rfb3_2.branch3.0.bias", "rfb3_2.branch3.0.weight", "rfb3_2.branch3.1.bias", "rfb3_2.branch3.1.weight", "rfb3_2.branch3.2.bias", "rfb3_2.branch3.2.weight", "rfb3_2.branch3.3.bias", "rfb3_2.branch3.3.weight", "rfb3_2.conv_cat.bias", "rfb3_2.conv_cat.weight", "rfb3_2.conv_res.bias", "rfb3_2.conv_res.weight", "rfb4_2.branch0.0.bias", "rfb4_2.branch0.0.weight", "rfb4_2.branch1.0.bias", "rfb4_2.branch1.0.weight", "rfb4_2.branch1.1.bias", "rfb4_2.branch1.1.weight", "rfb4_2.branch1.2.bias", "rfb4_2.branch1.2.weight", "rfb4_2.branch1.3.bias", "rfb4_2.branch1.3.weight", "rfb4_2.branch2.0.bias", "rfb4_2.branch2.0.weight", "rfb4_2.branch2.1.bias", "rfb4_2.branch2.1.weight", "rfb4_2.branch2.2.bias", "rfb4_2.branch2.2.weight", "rfb4_2.branch2.3.bias", "rfb4_2.branch2.3.weight", "rfb4_2.branch3.0.bias", "rfb4_2.branch3.0.weight", "rfb4_2.branch3.1.bias", "rfb4_2.branch3.1.weight", "rfb4_2.branch3.2.bias", "rfb4_2.branch3.2.weight", "rfb4_2.branch3.3.bias", "rfb4_2.branch3.3.weight", "rfb4_2.conv_cat.bias", "rfb4_2.conv_cat.weight", "rfb4_2.conv_res.bias", "rfb4_2.conv_res.weight", "rfb5_2.branch0.0.bias", "rfb5_2.branch0.0.weight", "rfb5_2.branch1.0.bias", "rfb5_2.branch1.0.weight", "rfb5_2.branch1.1.bias", "rfb5_2.branch1.1.weight", "rfb5_2.branch1.2.bias", "rfb5_2.branch1.2.weight", "rfb5_2.branch1.3.bias", "rfb5_2.branch1.3.weight", "rfb5_2.branch2.0.bias", "rfb5_2.branch2.0.weight", "rfb5_2.branch2.1.bias", "rfb5_2.branch2.1.weight", "rfb5_2.branch2.2.bias", "rfb5_2.branch2.2.weight", "rfb5_2.branch2.3.bias", 
"rfb5_2.branch2.3.weight", "rfb5_2.branch3.0.bias", "rfb5_2.branch3.0.weight", "rfb5_2.branch3.1.bias", "rfb5_2.branch3.1.weight", "rfb5_2.branch3.2.bias", "rfb5_2.branch3.2.weight", "rfb5_2.branch3.3.bias", "rfb5_2.branch3.3.weight", "rfb5_2.conv_cat.bias", "rfb5_2.conv_cat.weight", "rfb5_2.conv_res.bias", "rfb5_2.conv_res.weight", "agg2.conv_upsample1.bias", "agg2.conv_upsample1.weight", "agg2.conv_upsample2.bias", "agg2.conv_upsample2.weight", "agg2.conv_upsample3.bias", "agg2.conv_upsample3.weight", "agg2.conv_upsample4.bias", "agg2.conv_upsample4.weight", "agg2.conv_upsample5.bias", "agg2.conv_upsample5.weight", "agg2.conv_concat2.bias", "agg2.conv_concat2.weight", "agg2.conv_concat3.bias", "agg2.conv_concat3.weight", "agg2.conv4.bias", "agg2.conv4.weight", "agg2.conv5.bias", "agg2.conv5.weight", "HA.gaussian_kernel".
Unexpected key(s) in state_dict: "loss", "optimizer_state_dict", "model_state_dict", "epoch".
Может быть, я сохраняю состояния не так, как следует?
Странно то, что до добавления кода возобновления обучения я просто сохранял модель после каждой эпохи с помощью torch.save(model.state_dict(), save_full_path_torch):
мне удалось обучить модель в течение 10 эпох, и она по-прежнему работает при тестировании.