Когда я обучал свою модель UNet в PyTorch, я обнаружил, что показатели dice и mIoU на обучении и валидации заметно лучше, чем на тесте. Подробности ниже:
Epoch 19/999
----------
LR 0.001
train: bce: 0.053801, dice: 0.055995, loss: 0.054898, miou: 0.893865
val: bce: 0.098117, dice: 0.109000, loss: 0.103558, miou: 0.808214
EarlyStopping counter: 8 out of 8
Early stopping
Best val loss: 0.100456
test: bce: 0.273259, dice: 0.379796, loss: 0.326527, miou: 0.385055
Я попробовал несколько решений, найденных в интернете, но ситуация не изменилась. Вот что я пробовал:
- Убедился, что каждый слой BN используется только один раз
- Убрал слои BN
- Аугментация набора данных
- Использование L1/L2-регуляризации в модели
Вот мой код:
def double_conv(in_channels, out_channels_1, out_channels_2):
    """Build two stacked 3x3 conv -> batch-norm -> ReLU stages.

    Args:
        in_channels: Channel count of the tensor fed to the first convolution.
        out_channels_1: Channel count produced by the first convolution.
        out_channels_2: Channel count produced by the second convolution.

    Returns:
        An ``nn.Sequential`` holding six layers in the order
        conv, BN, ReLU, conv, BN, ReLU. Each convolution uses a 3x3 kernel
        with ``padding=1``, so the spatial size of the input is preserved.
    """
    stages = []
    fan_in = in_channels
    for fan_out in (out_channels_1, out_channels_2):
        stages += [
            nn.Conv2d(fan_in, fan_out, 3, padding=1),
            nn.BatchNorm2d(num_features=fan_out, eps=1e-5, momentum=0.1, affine=True),
            nn.ReLU(inplace=True),
        ]
        # The next stage consumes what this stage produced.
        fan_in = fan_out
    return nn.Sequential(*stages)
class UNet_one(nn.Module):
    """One-level UNet: one encoder block, one bottleneck block, one decoder block.

    The network takes a single-channel input (the first convolution is built
    with ``in_channels=1``) and returns the raw output of a 1x1 convolution
    with ``n_class`` channels; no activation is applied to the output here.

    Args:
        n_class: Number of output channels of the final 1x1 convolution.
        cfg: Six channel widths: ``cfg[0:2]`` for the encoder block,
            ``cfg[2:4]`` for the bottleneck block and ``cfg[4:6]`` for the
            decoder block.

    Note:
        The input height and width must be even: ``MaxPool2d(2)`` halves them
        (rounding down) and the x2 upsample must match the skip tensor
        exactly for ``torch.cat`` to succeed.
    """

    def __init__(self, n_class, cfg=(64, 64, 128, 128, 64, 64)):
        super().__init__()
        self.n_class = n_class
        # Store a private copy: a list default would be shared by every
        # instance, and a caller's list could otherwise be mutated through us.
        self.cfg = list(cfg)
        self.dconv_down1 = double_conv(1, self.cfg[0], self.cfg[1])
        self.dconv_down2 = double_conv(self.cfg[1], self.cfg[2], self.cfg[3])
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # The decoder sees the upsampled bottleneck concatenated with the skip.
        self.dconv_up1 = double_conv(self.cfg[1] + self.cfg[3], self.cfg[4], self.cfg[5])
        self.conv_last = nn.Conv2d(self.cfg[5], n_class, 1)

    def forward(self, x):
        """Run the network and return the un-activated ``n_class``-channel map."""
        conv1 = self.dconv_down1(x)  # kept for the skip connection
        x = self.maxpool(conv1)
        x = self.dconv_down2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)  # concatenate along the channel axis
        x = self.dconv_up1(x)
        out = self.conv_last(x)
        return out
class training():
    """Train a model with a train/val loop, best-checkpoint saving and early stopping.

    Args:
        args: Parsed options object. The attributes read here are
            ``train_imgdir``, ``train_maskdir``, ``train_aug``, ``save``,
            ``epochs``, ``train_batchsize``, ``val_batchsize``,
            ``test_batchsize``, ``patch_h_w``, ``N_patches`` and
            ``spilt_ratio``.
        model: The ``nn.Module`` to optimise.
    """

    def __init__(self, args, model):
        self.args = args
        # Training-set file locations.
        self.dir_img = args.train_imgdir
        self.dir_mask = args.train_maskdir
        self.aug = args.train_aug
        self.aug_num = 1000  # stored only; not read anywhere in this class
        self.save = args.save  # path the best model's state_dict is written to
        # Hyper-parameters.
        self.epochs = args.epochs
        self.train_batch_size = args.train_batchsize
        self.val_batch_size = args.val_batchsize
        self.test_batch_size = args.test_batchsize
        self.patch_h_w = args.patch_h_w
        self.N_patches = args.N_patches
        self.spilt_ratio = args.spilt_ratio
        self.model = model

    def dataset_preparation(self):
        """Build the training and validation data loaders.

        Returns:
            A dict with keys ``'train'`` and ``'val'`` mapping to ``DataLoader``
            objects over ``SimDataset`` instances.
        """
        # Square patches: the same size is used for height and width.
        train_set, val_set = get_spilt_trainDataset(
            self.dir_img, self.dir_mask, spilt_ratio=self.spilt_ratio,
            patch_h=self.patch_h_w, patch_w=self.patch_h_w,
            N_patches=self.N_patches, inside=True)
        # Each split is indexable; items 0 and 1 are handed to SimDataset
        # (presumably images and masks -- confirm against get_spilt_trainDataset).
        train_set = SimDataset(train_set[0], train_set[1])
        val_set = SimDataset(val_set[0], val_set[1])
        dataloaders = {
            'train': DataLoader(train_set, batch_size=self.train_batch_size, shuffle=True, num_workers=8),
            'val': DataLoader(val_set, batch_size=self.val_batch_size, shuffle=True, num_workers=8)
        }
        return dataloaders

    def train_model(self, dataloaders, optimizer, scheduler):
        """Run the epoch loop until ``self.epochs`` is reached or early stopping fires.

        Every epoch runs a training pass followed by a validation pass. When
        the validation loss improves, the model's ``state_dict`` is saved to
        ``self.save``.

        Args:
            dataloaders: Dict with ``'train'`` and ``'val'`` data loaders.
            optimizer: Optimizer updating ``self.model``'s parameters.
            scheduler: Learning-rate scheduler wrapping ``optimizer``.

        Returns:
            ``self.model`` in its state after the last executed epoch (the
            best-validation weights are the ones written to ``self.save``).
        """
        best_loss = 1e10
        early_stopping = EarlyStopping(patience=8, verbose=True)
        for epoch in range(self.epochs):
            print('Epoch {}/{}'.format(epoch, self.epochs - 1))
            print('-' * 10)
            since = time.time()
            for phase in ['train', 'val']:
                is_train = phase == 'train'
                if is_train:
                    # Report the learning rate this epoch will train with.
                    for param_group in optimizer.param_groups:
                        print("LR", param_group['lr'])
                    self.model.train(mode=True).float()  # Set model to training mode
                else:
                    self.model.eval()  # Set model to evaluate mode
                metrics = defaultdict(float)
                epoch_samples = 0
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device).float()
                    labels = labels.to(device)
                    optimizer.zero_grad()
                    # Track gradients only while training.
                    with torch.set_grad_enabled(is_train):
                        outputs = self.model(inputs)
                        loss = calc_loss_entorpy(outputs, labels, metrics)
                        if is_train:
                            loss.backward()
                            optimizer.step()
                    epoch_samples += inputs.size(0)
                if is_train:
                    # Step the schedule only after this epoch's optimizer
                    # updates: stepping it first (as PyTorch >= 1.1 warns)
                    # skips the first value of the learning-rate schedule.
                    scheduler.step()
                print_metrics(metrics, epoch_samples, phase)
                epoch_loss = metrics['loss'] / epoch_samples
                if phase == 'val':
                    val_loss_epoch = epoch_loss
                    if epoch_loss < best_loss:
                        best_loss = epoch_loss
                        print("saving best model")
                        torch.save(self.model.state_dict(), self.save)
            # EarlyStopping judgement, based on this epoch's validation loss.
            early_stopping(val_loss_epoch, self.model)
            if early_stopping.early_stop:
                print("Early stopping")
                break
            time_elapsed = time.time() - since
            print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print('Best val loss: {:4f}'.format(best_loss))
        return self.model

    def start(self):
        """Create loaders, optimizer and scheduler, then train.

        Returns:
            The model returned by :meth:`train_model`.
        """
        # Training
        dataloaders = self.dataset_preparation()
        # Only parameters that require gradients are handed to the optimizer.
        optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=1e-3, weight_decay=1e-5)
        # optimizer_ft = optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()), lr=1e-3, momentum=0.9, dampening=0, weight_decay=0)
        # Multiply the learning rate by 0.1 every 30 scheduler steps (epochs).
        exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.1)
        model = self.train_model(dataloaders=dataloaders, optimizer=optimizer_ft, scheduler=exp_lr_scheduler)
        return model
if __name__ == '__main__':
    # Build the single-output-channel UNet on the target device.
    model = UNet_one(n_class=1).to(device)
    args = get_args()
    if args.load:
        # Start from the checkpoint at args.model when requested. map_location
        # keeps loading working when the checkpoint was saved on a different
        # device than the one in use now (e.g. GPU checkpoint, CPU-only host).
        model.load_state_dict(torch.load(args.model, map_location=device))
    my_training = training(args, model)
    my_training.start()
Мой тестовый код:
class examination(object):
    """Evaluate a model on the test patches and write per-patch images to disk.

    Args:
        model: The ``nn.Module`` to evaluate.
        args: Parsed options object. The attributes read here are
            ``test_batchsize``, ``test_aug``, ``patch_h_w``, ``test_imgdir``
            and ``test_maskdir``.
    """

    def __init__(self, model, args):
        self.args = args
        self.model = model
        self.batchsize = self.args.test_batchsize
        # Running metric sums, filled by calc_testmetrics during start().
        self.metrics = defaultdict(float)
        self.aug = self.args.test_aug
        self.aug_test_num = 1000  # stored only; not read anywhere in this class
        self.patch_h_w = args.patch_h_w
        self.patch_imgs, self.patch_targets = None, None
        self.test_border_masks = get_full_test_broder_mask()

    def dataset_prepared(self):
        """Load the test patches and wrap them in a ``DataLoader``.

        Returns:
            A ``DataLoader`` over a ``SimDataset`` of the test patches.
        """
        # Patches (square: the same size is used for height and width).
        self.patch_imgs, self.patch_targets = get_examinationDataset(
            self.args.test_imgdir, self.args.test_maskdir,
            patch_h=self.patch_h_w, patch_w=self.patch_h_w)
        data_consistency_check(self.patch_imgs, self.patch_targets)
        # Make Dataset
        test_dataset = SimDataset(self.patch_imgs, self.patch_targets)
        # No shuffling at test time: the aggregated metrics do not need it, and
        # a fixed order makes the saved file indices follow the dataset order.
        test_loader = DataLoader(test_dataset, batch_size=self.batchsize, shuffle=False, num_workers=8)
        return test_loader

    @staticmethod
    def _sample_index(count, position, batch_size):
        """Return a unique 0-based index for item ``position`` of 1-based batch ``count``."""
        return (count - 1) * batch_size + position

    def visual(self, inputs, labels, outputs, count):
        """Write the input, label and output of every sample in a batch as JPEGs.

        Args:
            inputs: Batch of model inputs (4-D tensor).
            labels: Batch of target masks (4-D tensor).
            outputs: Batch of model outputs (4-D tensor).
            count: 1-based number of this batch; combined with the batch size
                to give every sample its own file name.
        """
        inputs = inputs.cpu().numpy()
        labels = labels.cpu().numpy()
        outputs = outputs.cpu().numpy()
        # Move axis 1 (the channel axis of the Conv2d-style tensors) to the end.
        outputs = np.transpose(outputs, (0, 2, 3, 1))
        inputs = np.transpose(inputs, (0, 2, 3, 1))
        labels = np.transpose(labels, (0, 2, 3, 1))
        for i in range(outputs.shape[0]):
            # A product such as i * count repeats across batches (it is 0 for
            # the first item of every batch), which silently overwrites files;
            # a running offset gives one file per sample.
            index = self._sample_index(count, i, self.batchsize)
            ori_img = inputs[i] * 255
            gt_img = labels[i] * 255
            # NOTE(review): outputs are written exactly as the model returned
            # them, scaled by 255. If the model emits un-activated scores this
            # is not a 0-255 probability map -- confirm whether a sigmoid
            # should be applied before saving.
            output_img = outputs[i] * 255
            cv2.imwrite('../outcome/result_ori/ori_' + str(index) + '.jpg', ori_img)
            cv2.imwrite('../outcome/result_label/label_' + str(index) + '.jpg', gt_img)
            cv2.imwrite('../outcome/result_out/output_' + str(index) + '.jpg', output_img)

    def start(self):
        """Run the model over the whole test set, save images and print metrics."""
        self.model.eval()  # Set model to the evaluation mode
        test_loader = self.dataset_prepared()
        epoch_samples = 0
        count = 1
        # Inference only. Disabling autograd here (rather than relying on the
        # caller) also guarantees the outputs can be converted to NumPy.
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = self.model(inputs.float())
                calc_testmetrics(outputs, labels, self.metrics)
                epoch_samples += inputs.size(0)
                # display
                self.visual(inputs, labels, outputs, count)
                count += 1
        print_metrics(self.metrics, epoch_samples=epoch_samples, phase='test')
if __name__ == '__main__':
    # Inference only: no autograd graph is needed for evaluation.
    with torch.no_grad():
        args = get_args()
        model = UNet_one(n_class=1).to(device)
        # model = UNet_one_BN(n_class=1).to(device)
        # map_location keeps loading working when the checkpoint was saved on a
        # different device than the one in use now.
        model.load_state_dict(torch.load(args.test_model, map_location=device))
        my_predict = examination(model=model, args=args)
        # Run the evaluation directly in this process. Passing
        # `my_predict.start()` (a call, not a callable) as a Process target
        # executes it here anyway and then starts a child with target=None
        # that does nothing, so the extra process is dropped.
        my_predict.start()
Код, в котором каждый слой BN используется только один раз:
class UNet_one_BN(nn.Module):
    """One-level UNet in which every conv and batch-norm layer is its own attribute.

    The network takes a single-channel input (the first convolution is built
    with ``in_channels=1``) and returns the raw output of a 1x1 convolution
    with ``n_class`` channels; no activation is applied to the output here.

    Args:
        n_class: Number of output channels of the final 1x1 convolution.
        cfg: Six channel widths: ``cfg[0:2]`` for the first block,
            ``cfg[2:4]`` for the second block and ``cfg[4:6]`` for the third.

    Note:
        The input height and width must be even: ``MaxPool2d(2)`` halves them
        (rounding down) and the x2 upsample must match the skip tensor
        exactly for ``torch.cat`` to succeed.
    """

    def __init__(self, n_class, cfg=(64, 64, 128, 128, 64, 64)):
        super().__init__()
        self.n_class = n_class
        # Store a private copy: a list default would be shared by every
        # instance, and a caller's list could otherwise be mutated through us.
        self.cfg = list(cfg)
        # Parameter-free layers, safe to reuse at several places in forward().
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # first block
        self.conv1_0 = nn.Conv2d(1, self.cfg[0], 3, padding=1)
        self.bn1_0 = nn.BatchNorm2d(self.cfg[0])
        self.conv1_1 = nn.Conv2d(self.cfg[0], self.cfg[1], 3, padding=1)
        self.bn1_1 = nn.BatchNorm2d(self.cfg[1])
        # second block
        self.conv2_0 = nn.Conv2d(self.cfg[1], self.cfg[2], 3, padding=1)
        self.bn2_0 = nn.BatchNorm2d(self.cfg[2])
        self.conv2_1 = nn.Conv2d(self.cfg[2], self.cfg[3], 3, padding=1)
        self.bn2_1 = nn.BatchNorm2d(self.cfg[3])
        # third block: consumes the upsampled second block plus the skip
        self.conv3_0 = nn.Conv2d(self.cfg[1] + self.cfg[3], self.cfg[4], 3, padding=1)
        self.bn3_0 = nn.BatchNorm2d(self.cfg[4])
        self.conv3_1 = nn.Conv2d(self.cfg[4], self.cfg[5], 3, padding=1)
        self.bn3_1 = nn.BatchNorm2d(self.cfg[5])
        # last block
        self.conv_last = nn.Conv2d(self.cfg[5], n_class, 1)

    def forward(self, x):
        """Run the network and return the un-activated ``n_class``-channel map."""
        # first block; its result is also the skip connection
        skip = self.relu(self.bn1_0(self.conv1_0(x)))
        skip = self.relu(self.bn1_1(self.conv1_1(skip)))
        # second block, at half resolution
        x = self.maxpool(skip)
        x = self.relu(self.bn2_0(self.conv2_0(x)))
        x = self.relu(self.bn2_1(self.conv2_1(x)))
        # back to full resolution, then join with the skip along channels
        x = torch.cat([self.upsample(x), skip], dim=1)
        # third block
        x = self.relu(self.bn3_0(self.conv3_0(x)))
        x = self.relu(self.bn3_1(self.conv3_1(x)))
        # output
        return self.conv_last(x)