Я пытаюсь построить в Keras сеть U-Net для сегментации изображений, в которую я добавил слои BatchNormalization. После обучения в течение 60 эпох модель достигла коэффициента Дайса (Dice) 0,85 на обучающей выборке и 0,8 на валидационной. Но на тестовых данных результат сегментации всегда равен 0.
Когда я убрал слои BatchNormalization, результат сегментации был низким, но не нулевым.
Вот код моей модели:
def conv_bn(x, flt, k=3, rate=1):
    """Apply a same-padded k x k convolution, batch normalization, then ReLU.

    Args:
        x: input tensor.
        flt: number of convolution filters.
        k: side length of the square convolution kernel.
        rate: dilation rate forwarded to ``Conv2D``.

    Returns:
        The ReLU-activated output tensor.
    """
    # BatchNormalization uses its Keras defaults here.
    conv_out = Conv2D(filters=flt, kernel_size=(k, k), dilation_rate=rate, padding='same')(x)
    bn_out = BatchNormalization()(conv_out)
    relu_out = Activation('relu')(bn_out)
    return relu_out
def get_unet(imgs_row, imgs_col, flt=64, drop_out=False, drop_rate=0.2):
    """Build a 4-level U-Net with a single-channel input and a sigmoid output.

    Every stage is two ``conv_bn`` blocks (3x3 conv + BatchNorm + ReLU), with an
    optional Dropout between them. The encoder uses ``flt * (1, 2, 4, 8)`` filters
    followed by 2x2 max pooling. The bottleneck uses ``flt * 16`` then ``flt * 8``.
    Each decoder stage upsamples with a 2x2 strided ``Conv2DTranspose`` and
    concatenates the matching encoder output (skip connection).

    Args:
        imgs_row: input height.
        imgs_col: input width.
        flt: base number of filters.
        drop_out: if True, insert Dropout between the two convs of every stage.
        drop_rate: dropout rate used when ``drop_out`` is True.

    Returns:
        An uncompiled Keras ``Model`` mapping (imgs_row, imgs_col, 1) inputs to a
        1-channel sigmoid map of the same spatial size.
    """
    print("start building NN")

    def _double_conv(t, first_flt, second_flt):
        # Two conv_bn blocks with optional dropout in between (one U-Net stage).
        t = conv_bn(x=t, flt=first_flt, k=3)
        if drop_out:
            t = Dropout(drop_rate)(t)
        return conv_bn(x=t, flt=second_flt, k=3)

    inputs = Input((imgs_row, imgs_col, 1))

    # Encoder: keep each stage's output for the decoder's skip connections.
    skips = []
    x = inputs
    for mult in (1, 2, 4, 8):
        x = _double_conv(x, flt * mult, flt * mult)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck.
    x = _double_conv(x, flt * 16, flt * 8)

    # Decoder: (upsampling/first-conv multiplier, second-conv multiplier) per stage,
    # paired with encoder outputs from deepest to shallowest.
    decoder_mults = ((8, 4), (4, 2), (2, 1), (1, 1))
    for (up_mult, out_mult), skip in zip(decoder_mults, reversed(skips)):
        up = Conv2DTranspose(flt * up_mult, (2, 2), strides=(2, 2), padding='same')(x)
        x = concatenate([up, skip], axis=3)
        x = _double_conv(x, flt * up_mult, flt * out_mult)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    # Keras 2 / tf.keras keyword names are `inputs` / `outputs`; the legacy
    # `input=` / `output=` spelling is rejected by tf.keras.
    return Model(inputs=inputs, outputs=[outputs])
Вот код обучения:
# Geometric augmentation shared by images and masks. Both generators get identical
# arguments and the same seed so each image is transformed exactly like its mask.
#
# featurewise_center / featurewise_std_normalization were removed:
# - they only take effect after image_datagen.fit(train_data), which was never
#   called, so Keras merely warned and skipped them;
# - they must never be applied to the binary masks.
# If input normalization is wanted, apply the SAME normalization to train,
# validation and test images (e.g. fit once on train_data and reuse mean/std).
data_gen_args = dict(rotation_range=25,
                     shear_range=0.2,
                     horizontal_flip=True,
                     vertical_flip=True)
image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)
seed = 50
batch_size = 3
image_generator = image_datagen.flow(train_data, seed=seed, batch_size=batch_size)
mask_generator = mask_datagen.flow(train_label, seed=seed, batch_size=batch_size)
train_generator = zip(image_generator, mask_generator)

model = get_unet(imgs_row, imgs_col)
sgd = SGD(lr=init_lr, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd, loss=dice_coef_loss, metrics=[dice_coef])

ver = 'unet'
csv_logger = CSVLogger(os.path.join('logs', ver + ".csv"))
# Select the checkpoint by validation loss: validation runs BatchNorm in inference
# mode (moving statistics), which is how the model is used at test time.
model_checkpoint = ModelCheckpoint(
    os.path.join('checkpoints', ver + ".h5"),
    monitor='val_loss',
    save_best_only=True)
# `reduce_lr` was listed here although its ReduceLROnPlateau definition was
# commented out (NameError). Re-create it and append it to this list to enable it.
callbacks = [model_checkpoint, csv_logger]
# len(image_generator) == ceil(len(train_data) / batch_size), i.e. one pass over
# the training set per epoch (len(train_data) steps was batch_size passes).
model.fit_generator(train_generator,
                    steps_per_epoch=len(image_generator),
                    epochs=50,
                    callbacks=callbacks,
                    validation_data=(valid_data, valid_label))
Вот код тестирования:
model = load_model(os.path.join('checkpoints', model_name),
                   custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef})
data, label = get_one_patient_end2end(patient)
pred_mask = np.zeros_like(data)
# The pasted `try:` had no `except`/`finally` (a SyntaxError), so it was removed;
# any failure on a slice now surfaces instead of being hidden.
for sli, (slice_img, slice_label) in enumerate(zip(data, label)):
    # NOTE(review): this preprocessing must be identical to what produced
    # train_data / valid_data. BatchNorm makes the model sensitive to input
    # intensity scale, so a mismatch here can push every probability below 0.5.
    image = preprocess_front(preprocess(slice_img))
    pre_probability = model.predict(image)
    out_ori = (pre_probability > 0.5).astype(np.uint8)
    pred_mask[sli] = out_ori[0, :, :, 0]
    cur_dsc, _, _, _ = DSC_computation(slice_label, pred_mask[sli])
    # The max probability distinguishes "model outputs ~0 everywhere" from a
    # near-threshold problem when the Dice score is 0.
    print('Predict %dth slice: ' % sli, cur_dsc,
          'max prob: %.4f' % float(pre_probability.max()))
Я ожидал, что результат сегментации на тесте будет немного ниже 0,8, но никак не постоянный 0.