I am trying to train a DenseNet model on a set of patches. My dataset has 20k images for training and 6k for evaluation.
This is my main function:
if __name__ == "__main__":
    for epoch in range(START_EPOCH, START_EPOCH + hp.epoch):
        adjust_learning_rate(optimizer, epoch)
        train(epoch, hp.wrong_save)
        # mining(epoch)
        valid(epoch)
I adjust the learning rate at each epoch to apply learning-rate decay, then I run my train function, which feeds batches from my trainloader to the GPU, and then I run the valid function on my validation set to check some metrics at each epoch.
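For reference, adjust_learning_rate is essentially a step decay; a minimal sketch, reconstructed from the args.lr * (0.5 ** LR_DECAY) expression I print in valid() below (the logic that increments LR_DECAY is omitted here, so treat this as an approximation):

    def adjust_learning_rate(optimizer, epoch):
        # Step decay: the base lr is halved once per LR_DECAY increment.
        # epoch is unused in this sketch; in my script the decay is driven
        # by the global LR_DECAY counter.
        lr = args.lr * (0.5 ** LR_DECAY)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr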
I have an Nvidia GTX 1060 (6 GB) and training runs without problems, but when I run the valloader I get:
RuntimeError: CUDA out of memory. Tried to allocate 74.12 MiB (GPU 0; 5.93 GiB total capacity; 4.73 GiB already allocated; 75.06 MiB free; 19.57 MiB cached)
I would like to know what you think the best approach is here. Can I run evaluation without loading my evaluation set onto the GPU? Is there a workaround?
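My suspicion after some searching: volatile=True became a no-op in PyTorch 0.4+ (and I must be on >= 0.4, since I call loss.item()), so the validation forward passes still build the autograd graph and the activations accumulate until the OOM. If that is right, wrapping the loop in torch.no_grad() should cap memory; a minimal, untested sketch against my valid() below (net, valloader, and criterion are the same objects as in my code):

    import numpy as np
    import torch

    def valid_no_grad():
        # Sketch only: net, valloader, criterion are the globals from my script.
        net.eval()
        valid_loss = 0
        with torch.no_grad():  # no graph is recorded, so activations are freed per batch
            for inputs, targets in valloader:
                inputs = inputs.cuda()
                targets = torch.FloatTensor(np.array(targets).astype(float)).cuda()
                outputs = torch.squeeze(net(inputs))
                valid_loss += criterion(outputs, targets).item()
        return valid_loss / len(valloader)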
Here are my train and valid functions:
# Optimization, Loss Function Init
criterion = nn.BCELoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=hp.momentum, weight_decay=hp.weight_decay)
def train(epoch, wrong_save=False):
    '''Train net using patches of slide.
    Save a csv file listing the patch file names predicted incorrectly.
    Args:
        epoch (int): current epoch
        wrong_save (bool): If True, save the csv file that has the patch file names
            predicted incorrectly
    '''
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    wrong_list = []
    for batch_idx, (inputs, targets, filename) in enumerate(trainloader):
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = torch.FloatTensor(np.array(targets).astype(float)).cuda()
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = net(inputs)
        outputs = torch.squeeze(outputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        total += targets.size(0)
        batch_size = targets.shape[0]
        # Shift then floor: prediction becomes 1 where the raw output >= 1 - THRESHOLD.
        outputs += Variable((torch.ones(batch_size) * THRESHOLD).cuda())
        outputs = torch.floor(outputs)
        correct += outputs.data.eq(targets.data).cpu().sum()
        filename_list = filename
        if wrong_save:
            for idx in range(len(filename_list)):
                if outputs.data[idx] != targets.data[idx]:
                    wrong_name = filename_list[idx]
                    wrong_list.append(wrong_name)
        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    if wrong_save:
        wrong_csv = open(cf.wrong_path+'wrong_data_epoch'+str(epoch)+'.csv', 'w', encoding='utf-8')
        wr = csv.writer(wrong_csv)
        for name in wrong_list:
            wr.writerow([name])
        wrong_csv.close()
    CUR_TRA_ACC.append(100.*correct/total)
def valid(epoch):
    '''Validate net using patches of slide.
    Save checkpoint if the AUC score is higher than the saved checkpoint's.
    Args:
        epoch (int): current epoch
    '''
    global BEST_AUC
    global THRESHOLD
    global LR_CHANCE
    global CK_CHANCE
    global LR_DECAY

    net.eval()
    valid_loss = 0
    total = 0
    correct = 0
    outputs_list = np.array([])
    targets_list = np.array([])
    for batch_idx, (inputs, targets) in enumerate(valloader):
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = torch.FloatTensor(np.array(targets).astype(float)).cuda()
        batch_size = targets.shape[0]
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        outputs = net(inputs)
        total += targets.size(0)
        outputs = torch.squeeze(outputs)
        loss = criterion(outputs, targets)
        valid_loss += loss.item()
        _outputs = np.array(outputs.data.cpu()).astype(float)
        _targets = np.array(targets.data.cpu()).astype(float)
        outputs_list = np.append(outputs_list, _outputs)
        targets_list = np.append(targets_list, _targets)
        # Shift then floor: prediction becomes 1 where the raw output >= THRESHOLD.
        outputs += Variable((torch.ones(batch_size) * (1-THRESHOLD)).cuda())
        outputs = torch.floor(outputs)
        correct += int(outputs.eq(targets).cpu().sum())
        progress_bar(batch_idx, len(valloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (valid_loss/(batch_idx+1), 100.*correct/total, correct, total))

    correct, tp, tn, fp, fn, recall, precision, specificity, f1_score, auc, threshold = stats(outputs_list, targets_list)
    acc = correct/total
    THRESHOLD = threshold
    print(tp, tn, fp, fn)
    print('Acc: %.3f, Recall: %.3f, Prec: %.3f, Spec: %.3f, F1: %.3f, Thres: %.3f, AUC: %.3f'
          % (acc, recall, precision, specificity, f1_score, threshold, auc))
    print('%17s %12s\n%-11s %-8d %-8d\n%-11s %-8d %-8d'
          % ('Tumor', 'Normal', 'pos', tp, fp, 'neg', fn, tn))
    print("lr: ", args.lr * (0.5 ** LR_DECAY), "lr chance:", LR_CHANCE)

    # plot data
    CUR_EPOCH.append(epoch)
    CUR_VAL_ACC.append(acc)
    CUR_LOSS.append(valid_loss/(batch_idx+1))
    CUR_LR.append(args.lr * (0.5 ** LR_DECAY))

    # Save checkpoint.
    if auc > BEST_AUC:
        print('saving...')
        BEST_AUC = auc
        state = {
            'net': net,
            'acc': acc,
            'loss': valid_loss,
            'recall': recall,
            'specificity': specificity,
            'precision': precision,
            'f1_score': f1_score,
            'auc': auc,
            'epoch': epoch,
            'lr': args.lr * (0.5 ** LR_DECAY),
            'threshold': threshold
        }
        torch.save(state, './checkpoint/ckpt.t7')
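A side note on the checkpoint: I currently pickle the whole net object in state. From the PyTorch serialization docs I understand that saving state_dict() is the recommended, more portable pattern, roughly:

    # Recommended pattern (weights only, instead of pickling the module):
    state = {
        'net': net.state_dict(),
        'auc': auc,
        'epoch': epoch,
    }
    torch.save(state, './checkpoint/ckpt.t7')
    # Later: net.load_state_dict(torch.load('./checkpoint/ckpt.t7')['net'])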
And this is how I build my datasets:
def get_dataset(train_transform, test_transform, train_max,
                val_max, subtest_max, ratio=0, mining_mode=False):
    '''Dataset function to get the train, valid, subtest, test, and mining datasets.
    Args:
        train_transform (torchvision.transforms): train set transform for data augmentation
        test_transform (torchvision.transforms): test set transform for data augmentation
        train_max (int): limit of the train set
        val_max (int): limit of the validation set
        subtest_max (int): limit of the subtest set
        ratio (int): for mining_mode, inclusion ratio of the train set compared to the mining set
        mining_mode (bool): If True, also return the mining dataset
    '''
    train_dataset = camel(cf.dataset_path + 'train/', usage='train',
                          limit=train_max, transform=train_transform)
    val_dataset = camel(cf.dataset_path + 'validation/', usage='val',
                        limit=val_max, transform=test_transform)
    subtest_dataset = camel(cf.dataset_path + 'test/', usage='subtest',
                            limit=subtest_max, transform=test_transform)
    test_dataset = camel(cf.test_path, usage='test', transform=test_transform)
    if mining_mode:
        mining_dataset = camel(cf.dataset_path + 'mining/', usage='mining',
                               train_ratio=ratio, transform=train_transform)
        return train_dataset, val_dataset, subtest_dataset, test_dataset, mining_dataset
    else:
        return train_dataset, val_dataset, subtest_dataset, test_dataset
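The loaders themselves are standard torch.utils.data.DataLoader wrappers around these datasets; the batch sizes below are placeholders (my real values come from hp). Since the crash happens during validation, one obvious workaround I am considering is simply a smaller batch_size for valloader:

    from torch.utils.data import DataLoader

    # Placeholder batch sizes, for illustration only.
    trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True,
                             num_workers=4, pin_memory=True)
    valloader = DataLoader(val_dataset, batch_size=32, shuffle=False,
                           num_workers=4, pin_memory=True)  # smaller batch lowers peak GPU memory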