PyTorch UNet: dice / miou on the test set is much lower than dice / miou on train and validation - PullRequest
0 votes
/ 07 January 2020

When I train my UNet model in PyTorch, the dice and miou on the training and validation sets come out much higher than on the test set. The details look like this:

Epoch 19/999
----------
LR 0.001
train: bce: 0.053801, dice: 0.055995, loss: 0.054898, miou: 0.893865
val: bce: 0.098117, dice: 0.109000, loss: 0.103558, miou: 0.808214
EarlyStopping counter: 8 out of 8
Early stopping
Best val loss: 0.100456
test: bce: 0.273259, dice: 0.379796, loss: 0.326527, miou: 0.385055
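calc_loss_entorpy and calc_testmetrics below are my own helper functions and are not shown; for context, a typical dice / miou computation for binary segmentation looks roughly like the sketch below (sigmoid on the logits, soft dice, thresholded IoU). My actual helpers may differ in detail:

def dice_coeff(logits, targets, eps=1e-6):
    # soft dice on sigmoid probabilities, averaged over the batch
    probs = torch.sigmoid(logits).view(logits.size(0), -1)
    targets = targets.float().view(targets.size(0), -1)
    inter = (probs * targets).sum(dim=1)
    union = probs.sum(dim=1) + targets.sum(dim=1)
    return ((2 * inter + eps) / (union + eps)).mean()

def iou_score(logits, targets, thresh=0.5, eps=1e-6):
    # hard IoU after thresholding the probabilities at 0.5
    preds = (torch.sigmoid(logits) > thresh).float().view(logits.size(0), -1)
    targets = targets.float().view(targets.size(0), -1)
    inter = (preds * targets).sum(dim=1)
    union = preds.sum(dim=1) + targets.sum(dim=1) - inter
    return ((inter + eps) / (union + eps)).mean()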

I have tried some solutions found online, but the situation has not changed. The things I tried are:

  1. Make sure each BatchNorm layer is used only once
  2. Remove the BatchNorm layers entirely
  3. Enlarge / augment the dataset
  4. Add L1 / L2 regularization to the model (see the sketch after this list)
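For point 4, by L1 / L2 I mean roughly the following: L2 through the weight_decay argument of the optimizer (which the Adam call in my training code below already uses) and, optionally, an explicit L1 term added on top of the segmentation loss. A minimal sketch, not my full training code:

# L2 regularization via weight decay
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# optional explicit L1 penalty added to the loss before backward()
l1_lambda = 1e-5
l1_term = sum(p.abs().sum() for p in model.parameters())
loss = loss + l1_lambda * l1_term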

Here is my code:

import time
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader

# SimDataset, EarlyStopping, get_spilt_trainDataset, calc_loss_entorpy,
# print_metrics and get_args come from my own project modules (not shown here)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # device used below


def double_conv(in_channels, out_channels_1, out_channels_2):
    # two 3x3 conv -> BN -> ReLU blocks, spatial size preserved by padding=1
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels_1, 3, padding=1),
        nn.BatchNorm2d(num_features=out_channels_1, eps=1e-5, momentum=0.1, affine=True),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels_1, out_channels_2, 3, padding=1),
        nn.BatchNorm2d(num_features=out_channels_2, eps=1e-5, momentum=0.1, affine=True),
        nn.ReLU(inplace=True)
    )


class UNet_one(nn.Module):
    def __init__(self, n_class, cfg=[64, 64, 128, 128, 64, 64]):
        super().__init__()
        self.n_class = n_class
        self.cfg = cfg
        self.dconv_down1 = double_conv(1, self.cfg[0], self.cfg[1])
        self.dconv_down2 = double_conv(self.cfg[1], self.cfg[2], self.cfg[3])

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.dconv_up1 = double_conv(self.cfg[1] + self.cfg[3], self.cfg[4], self.cfg[5])

        self.conv_last = nn.Conv2d(self.cfg[5], n_class, 1)

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        x = self.dconv_down2(x)
        x = self.upsample(x)

        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)
        out = self.conv_last(x)

        return out
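As a quick sanity check of this architecture (the input height and width must be even so that the upsampled tensor matches conv1 for the concatenation), something like the following can be run; illustrative only:

net = UNet_one(n_class=1)
dummy = torch.randn(2, 1, 64, 64)   # batch of 2 single-channel 64x64 patches
with torch.no_grad():
    print(net(dummy).shape)         # expected: torch.Size([2, 1, 64, 64])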


class training():
    def __init__(self, args, model):
        self.args = args
        # train_dataset filepath
        self.dir_img = args.train_imgdir
        self.dir_mask = args.train_maskdir
        self.aug = args.train_aug
        self.aug_num = 1000 
        self.save = args.save #the path of saving model
        #parameters
        self.epochs = args.epochs
        self.train_batch_size = args.train_batchsize
        self.val_batch_size = args.val_batchsize
        self.test_batch_size = args.test_batchsize
        self.patch_h_w = args.patch_h_w
        self.N_patches = args.N_patches
        self.spilt_ratio = args.spilt_ratio

        self.model = model

    def dataset_preparation(self):
        train_set, val_set = get_spilt_trainDataset(self.dir_img, self.dir_mask, spilt_ratio=self.spilt_ratio, patch_h=self.patch_h_w,
                                                    patch_w=self.patch_h_w, N_patches=self.N_patches, inside=True)

        train_set = SimDataset(train_set[0], train_set[1])
        val_set = SimDataset(val_set[0], val_set[1])

        dataloaders = {
            'train': DataLoader(train_set, batch_size=self.train_batch_size, shuffle=True, num_workers=8),
            'val': DataLoader(val_set, batch_size=self.val_batch_size, shuffle=True, num_workers=8)
        }
        return dataloaders

    def train_model(self, dataloaders, optimizer, scheduler):
        best_loss = 1e10

        early_stopping = EarlyStopping(patience=8, verbose=True)

        for epoch in range(self.epochs):
            print('Epoch {}/{}'.format(epoch, self.epochs - 1))
            print('-' * 10)
            since = time.time()
            for phase in ['train', 'val']:
                if phase == 'train':
                    scheduler.step()
                    for param_group in optimizer.param_groups:
                        print("LR", param_group['lr'])
                    self.model.train(mode=True).float()  # Set model to training mode
                else:
                    self.model.eval()  # Set model to evaluate mode
                metrics = defaultdict(float)
                epoch_samples = 0
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device).float()
                    labels = labels.to(device)
                    optimizer.zero_grad()
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = self.model(inputs)
                        loss = calc_loss_entorpy(outputs, labels, metrics)
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()
                    epoch_samples += inputs.size(0)

                print_metrics(metrics, epoch_samples, phase)
                epoch_loss = metrics['loss'] / epoch_samples

                if phase == 'val':
                    val_loss_epoch = epoch_loss
                    if epoch_loss < best_loss:
                        best_loss = epoch_loss
                        print("saving best model")
                        torch.save(self.model.state_dict(), self.save)

            #EarlyStopping Judgement
            early_stopping(val_loss_epoch, self.model)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            time_elapsed = time.time() - since
            print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

        print('Best val loss: {:4f}'.format(best_loss))
        return self.model

    def start(self):
        # Training
        dataloaders = self.dataset_preparation()
        optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=1e-3, weight_decay=1e-5)
        # optimizer_ft = optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()), lr=1e-3, momentum=0.9, dampening=0, weight_decay=0)
        exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.1)
        model = self.train_model(dataloaders= dataloaders, optimizer=optimizer_ft, scheduler=exp_lr_scheduler)
        return model


if __name__ == '__main__':
    model = UNet_one(n_class=1).to(device)
    args = get_args()
    if args.load:
        model.load_state_dict(torch.load(args.model))
    my_training = training(args, model)
    my_training.start()
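A check worth running at this point (sketch only: calc_loss_entorpy and print_metrics are the same project helpers used above, and dataloaders is the dict returned by dataset_preparation()): reload the saved best weights and re-score the validation split in eval mode. If the validation miou does not reproduce here, the gap comes from the evaluation path rather than from the test data itself:

model = UNet_one(n_class=1).to(device)
model.load_state_dict(torch.load(args.save))
model.eval()

metrics = defaultdict(float)
n_samples = 0
with torch.no_grad():
    for inputs, labels in dataloaders['val']:
        outputs = model(inputs.to(device).float())
        calc_loss_entorpy(outputs, labels.to(device), metrics)
        n_samples += inputs.size(0)
print_metrics(metrics, n_samples, 'recheck-val')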

My test code:

from collections import defaultdict

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader

# get_examinationDataset, data_consistency_check, get_full_test_broder_mask,
# SimDataset, calc_testmetrics, print_metrics and get_args are my own project code
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class examination(object):

    def __init__(self, model, args):
        self.args = args
        self.model = model
        self.batchsize = self.args.test_batchsize
        self.metrics = defaultdict(float)

        self.aug = self.args.test_aug
        self.aug_test_num = 1000
        self.patch_h_w = args.patch_h_w
        self.patch_imgs, self.patch_targets = None, None
        self.test_border_masks = get_full_test_broder_mask()

    def dataset_prepared(self):
        # Patches
        self.patch_imgs, self.patch_targets = get_examinationDataset(self.args.test_imgdir, self.args.test_maskdir,
                                                                     patch_h=self.patch_h_w, patch_w=self.patch_h_w)
        data_consistency_check(self.patch_imgs, self.patch_targets)
        # Make Dataset
        test_dataset = SimDataset(self.patch_imgs, self.patch_targets)
        test_loader = DataLoader(test_dataset, batch_size=self.batchsize, shuffle=True, num_workers=8)
        return test_loader

    def visual(self, inputs, labels, outputs, count):
        # note: outputs here are the raw logits from conv_last (no sigmoid applied)
        inputs = inputs.cpu().numpy()
        labels = labels.cpu().numpy()
        outputs = outputs.cpu().numpy()
        # print('outputs:', outputs.shape)
        outputs = np.transpose(outputs, (0, 2, 3, 1))
        inputs = np.transpose(inputs, (0, 2, 3, 1))
        labels = np.transpose(labels, (0, 2, 3, 1))
        for i in range(outputs.shape[0]):
            ori_img = inputs[i] * 255
            gt_img = labels[i] * 255
            output_img = outputs[i] * 255
            # give every patch a unique index: count is the 1-based batch number
            idx = (count - 1) * outputs.shape[0] + i
            cv2.imwrite('../outcome/result_ori/ori_' + str(idx) + '.jpg', ori_img)
            cv2.imwrite('../outcome/result_label/label_' + str(idx) + '.jpg', gt_img)
            cv2.imwrite('../outcome/result_out/output_' + str(idx) + '.jpg', output_img)

    def start(self):
        self.model.eval()  # Set model to the evaluation mode
        test_loader = self.dataset_prepared()
        epoch_samples = 0
        count = 1
        output_sum = np.zeros((1620, 1, self.patch_h_w, self.patch_h_w))  # currently unused
        pos = 0  # currently unused
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = self.model(inputs.float())
            calc_testmetrics(outputs, labels, self.metrics)
            epoch_samples += inputs.size(0)
            # display
            self.visual(inputs, labels, outputs, count)
            count += 1

        print_metrics(self.metrics, epoch_samples=epoch_samples, phase='test')

if __name__ == '__main__':
    with torch.no_grad():
        args = get_args()
        model = UNet_one(n_class=1).to(device)
        # model = UNet_one_BN(n_class=1).to(device)
        model.load_state_dict(torch.load(args.test_model))
        my_predict = examination(model=model, args=args)
        # run the evaluation directly; the previous mp.Process(target=my_predict.start())
        # already called start() in the parent process and gave the child nothing to do
        my_predict.start()
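Since the drop only appears on the test set, one thing worth ruling out is a preprocessing / normalization mismatch between get_spilt_trainDataset and get_examinationDataset (for example, one scaling patches to [0, 1] and the other not). A minimal sketch of the kind of comparison I mean, assuming both return numpy arrays of patches (train_set[0] from dataset_preparation() and patch_imgs from dataset_prepared()):

def describe(name, arr):
    # print basic statistics so the two patch sets can be compared side by side
    arr = np.asarray(arr, dtype=np.float64)
    print(name, 'shape:', arr.shape,
          'min: %.4f  max: %.4f  mean: %.4f  std: %.4f'
          % (arr.min(), arr.max(), arr.mean(), arr.std()))

describe('train patches', train_set[0])
describe('test patches', patch_imgs)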

The code where each BN layer is used only once:

class UNet_one_BN(nn.Module):

    def __init__(self, n_class, cfg=[64, 64, 128, 128, 64, 64]):
        super().__init__()
        self.n_class = n_class
        self.cfg = cfg
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        #first block
        self.conv1_0 = nn.Conv2d(1, self.cfg[0], 3, padding=1)
        self.bn1_0 = nn.BatchNorm2d(self.cfg[0])
        self.conv1_1 = nn.Conv2d(self.cfg[0], self.cfg[1], 3, padding=1)
        self.bn1_1 = nn.BatchNorm2d(self.cfg[1])
        #second block
        self.conv2_0 = nn.Conv2d(self.cfg[1], self.cfg[2], 3, padding=1)
        self.bn2_0 = nn.BatchNorm2d(self.cfg[2])
        self.conv2_1 = nn.Conv2d(self.cfg[2], self.cfg[3], 3, padding=1)
        self.bn2_1 = nn.BatchNorm2d(self.cfg[3])
        #third block
        self.conv3_0 = nn.Conv2d(self.cfg[1] + self.cfg[3], self.cfg[4], 3, padding=1)
        self.bn3_0 = nn.BatchNorm2d(self.cfg[4])
        self.conv3_1 = nn.Conv2d(self.cfg[4], self.cfg[5], 3, padding=1)
        self.bn3_1 = nn.BatchNorm2d(self.cfg[5])
        #last block
        self.conv_last = nn.Conv2d(self.cfg[5], n_class, 1)

    def forward(self, x):
        #first block
        conv1_0 = self.conv1_0(x)
        bn1_0 = self.bn1_0(conv1_0)
        relu1_0 = self.relu(bn1_0)
        conv1_1 = self.conv1_1(relu1_0)
        bn1_1 = self.bn1_1(conv1_1)
        relu1_1 = self.relu(bn1_1)
        mxpool = self.maxpool(relu1_1)
        #second block
        conv2_0 = self.conv2_0(mxpool)
        bn2_0 = self.bn2_0(conv2_0)
        relu2_0 = self.relu(bn2_0)
        conv2_1 = self.conv2_1(relu2_0)
        bn2_1 = self.bn2_1(conv2_1)
        relu2_1 = self.relu(bn2_1)
        upsample = self.upsample(relu2_1)

        concat = torch.cat([upsample, relu1_1], dim=1)
        #third block
        conv3_0 = self.conv3_0(concat)
        bn3_0 = self.bn3_0(conv3_0)
        relu3_0 = self.relu(bn3_0)
        conv3_1 = self.conv3_1(relu3_0)
        bn3_1 = self.bn3_1(conv3_1)
        relu3_1 = self.relu(bn3_1)
        #output
        out = self.conv_last(relu3_1)

        return out
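Related to point 1 above: to see how differently the BatchNorm layers behave with batch statistics versus the stored running statistics (which can diverge badly with small batch sizes), the loaded model can be run on the same test batch in .eval() and .train() mode and the outputs compared. Illustrative sketch only, reusing args and a test_loader like the one from the test script above:

net = UNet_one_BN(n_class=1).to(device)
net.load_state_dict(torch.load(args.test_model))
batch, _ = next(iter(test_loader))
batch = batch.to(device).float()

with torch.no_grad():
    net.eval()
    out_eval = torch.sigmoid(net(batch))    # BN uses the stored running statistics
    net.train()
    out_train = torch.sigmoid(net(batch))   # BN uses the statistics of this batch

print('mean abs difference:', (out_eval - out_train).abs().mean().item())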