Я пытаюсь решить задачу бинарной сегментации с помощью U-Net в TensorFlow 2.0 (модуль Keras). Классы сильно несбалансированы, поэтому приходится использовать веса классов (0,03 для фона и 1,0 для переднего плана). В обучающем наборе ~2500 образцов, в валидационном — ~250.
Образец данных (изображение и его маска):
Метрика — Intersection over Union (IoU, «пересечение над объединением»). Функция потерь — потери Жаккара (Jaccard loss). После ~10 эпох обучение прекращается из-за ранней остановки (early stopping). Потери очень высоки, а значение метрики очень низкое. Я пытался снизить скорость обучения, но это не сильно помогло.
Когда я пытаюсь использовать модель для предсказания, она дает мне просто черный квадрат.
В чем проблема с моей моделью? Я что-то упустил? Это недостаток архитектуры? Неправильная функция потерь / метрика? Проблема с весами классов? Буду признателен за любую помощь.
Сетевая архитектура:
from contextlib import redirect_stdout
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.layers import Input, BatchNormalization, Activation, Dropout
from tensorflow.python.keras.layers.convolutional import Conv2D, Conv2DTranspose
from tensorflow.python.keras.layers.pooling import MaxPooling2D
from tensorflow.python.keras.layers.merge import concatenate
import tensorflow as tf
# Build the session configuration in one expression: allow_growth is set on
# the GPU options so that the session is created with that flag enabled.
config = tf.compat.v1.ConfigProto(
    gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)
)
session = tf.compat.v1.Session(config=config)

# Side length (in pixels) of the square network input.
IMAGE_WIDTH = 768
def _double_conv(tensor, filters, kernel_size, data_format=None):
    """Apply two (Conv2D with ReLU -> BatchNormalization) stages to ``tensor``.

    Both convolutions use ``filters`` output channels, a square kernel of side
    ``kernel_size``, "he_normal" initialization and "same" padding.
    ``data_format`` is forwarded to the first convolution only; ``None`` is
    Conv2D's own default.
    """
    tensor = Conv2D(filters=filters, kernel_size=(kernel_size, kernel_size),
                    data_format=data_format, activation='relu',
                    kernel_initializer="he_normal", padding="same")(tensor)
    tensor = BatchNormalization()(tensor)
    tensor = Conv2D(filters=filters, kernel_size=(kernel_size, kernel_size),
                    activation='relu', kernel_initializer="he_normal",
                    padding="same")(tensor)
    return BatchNormalization()(tensor)


def get_unet(input_image, n_filters, kernel_size, dropout=0.5):
    """Build a U-Net style encoder/decoder model with a one-channel sigmoid output.

    Args:
        input_image: Keras input tensor the network is built on.
        n_filters: number of filters at the first level; deeper encoder levels
            use 2x, 4x and 8x this value and the bottleneck uses 16x.
        kernel_size: side of the square kernel used by every convolution and
            transposed convolution except the final 1x1 output convolution.
        dropout: dropout rate applied after each pooling step and after each
            skip-connection concatenation. The first encoder level uses half
            of this rate.

    Returns:
        A Keras ``Model`` mapping ``input_image`` to the output of a 1x1
        convolution with a single filter and sigmoid activation.
    """
    skip_connections = []
    x = input_image

    # Encoder: four levels of double convolution -> 2x2 max-pooling -> dropout.
    for level, multiplier in enumerate((1, 2, 4, 8)):
        # Only the very first convolution pins data_format explicitly.
        x = _double_conv(x, n_filters * multiplier, kernel_size,
                         data_format="channels_last" if level == 0 else None)
        skip_connections.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        # The first level uses a lighter dropout than the deeper ones.
        x = Dropout(dropout * 0.5 if level == 0 else dropout)(x)

    # Bottleneck.
    x = _double_conv(x, n_filters * 16, kernel_size)

    # Decoder: stride-2 transposed convolution, concatenation with the
    # matching encoder output, dropout, then a double convolution.
    for multiplier, skip in zip((8, 4, 2, 1), reversed(skip_connections)):
        x = Conv2DTranspose(n_filters * multiplier, (kernel_size, kernel_size),
                            strides=(2, 2), padding='same')(x)
        x = concatenate([x, skip])
        x = Dropout(dropout)(x)
        x = _double_conv(x, n_filters * multiplier, kernel_size)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    return Model(inputs=input_image, outputs=outputs)
# Instantiate the network on a square RGB input.
input_image = Input(shape=(IMAGE_WIDTH, IMAGE_WIDTH, 3), name='img')
model = get_unet(input_image, n_filters=16, kernel_size=3, dropout=0.05)

# Capture whatever model.summary() prints to stdout into a text file.
with open('binary_unet_summary.txt', 'w') as f, redirect_stdout(f):
    model.summary()

# Store the architecture (as returned by to_json) in a JSON file.
model_json = model.to_json()
with open("my_basic_unet.json", "w") as json_file:
    json_file.write(model_json)
Генератор данных и другие функции:
def calc_weights(masks_folder):
    """Calculate class weights according to the class distribution in a dataset.

    Every file in ``masks_folder`` is read as a grayscale image and its
    non-zero pixels are counted as class 1. The median of those counts is
    taken as the typical class-1 area; the remainder of an
    ``IMAGE_WIDTH`` x ``IMAGE_WIDTH`` image is taken as the class-0 area.
    NOTE(review): this assumes every mask is IMAGE_WIDTH x IMAGE_WIDTH pixels.

    Args:
        masks_folder: path of the folder that contains the mask images.

    Returns:
        ``[class_0_weight, class_1_weight]`` where the class-1 weight is 1.0
        and the class-0 weight is ``class_1_total / class_0_total``.

    Raises:
        IOError: if a file in the folder cannot be decoded as an image.
        statistics.StatisticsError: if the folder is empty.
        ZeroDivisionError: if the median class-1 area covers the whole image.
    """
    class_1_numbers = []
    for file_name in os.listdir(masks_folder):
        mask_path = os.path.join(masks_folder, file_name)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread does not raise on failure; fail with a clear message
            # instead of an opaque error further down.
            raise IOError(f"Cannot read mask image: {mask_path}")
        # Scaling by 1/255 is unnecessary here: it does not change which
        # pixels are non-zero.
        class_1_numbers.append(cv2.countNonZero(mask))

    class_1_total = int(statistics.median(class_1_numbers))
    class_0_total = int(IMAGE_WIDTH**2 - class_1_total)
    class_1_weight = 1.  # Maximum value to minority class
    # Proportional value to majority class for classes balance
    class_0_weight = class_1_total / class_0_total
    return [class_0_weight, class_1_weight]
def data_gen(templates_folder, masks_folder, image_width, batch_size):
    """Generate individual batches from a dataset, endlessly.

    File names are listed from ``templates_folder`` and shuffled; the mask for
    a template is read from ``masks_folder`` under the same file name. Pixel
    values of both are divided by 255. When fewer than ``batch_size`` files
    remain, the list is reshuffled and iteration restarts from the beginning,
    so a trailing partial batch is never produced.

    Args:
        templates_folder: folder with the input images.
        masks_folder: folder with the masks (same file names as the templates).
        image_width: side of the square images; used to size the batch arrays.
        batch_size: number of samples per yielded batch.

    Yields:
        ``(templates_pack, masks_pack)`` float arrays of shape
        ``(batch_size, image_width, image_width, 3)`` and
        ``(batch_size, image_width, image_width, 1)``.

    Raises:
        ValueError: if the folder holds fewer files than one batch.
    """
    counter = 0
    images_list = os.listdir(templates_folder)
    if len(images_list) < batch_size:
        raise ValueError(
            f"{templates_folder} contains {len(images_list)} files, "
            f"fewer than batch_size={batch_size}"
        )
    random.shuffle(images_list)
    while True:
        templates_pack = np.zeros((batch_size, image_width, image_width, 3), dtype='float')
        masks_pack = np.zeros((batch_size, image_width, image_width, 1), dtype='float')
        batch_names = images_list[counter:counter + batch_size]
        for slot, file_name in enumerate(batch_names):
            template = cv2.imread(templates_folder + '/' + file_name) / 255.
            templates_pack[slot] = template
            mask = cv2.imread(masks_folder + '/' + file_name, cv2.IMREAD_GRAYSCALE) / 255.
            # Add a trailing channel axis so the 2-D mask fits the
            # (image_width, image_width, 1) slot of masks_pack.
            masks_pack[slot] = np.expand_dims(mask, axis=2)
        counter += batch_size
        # Restart only when the NEXT batch would run past the end of the list.
        # A ">=" test here would skip the final full batch whenever the number
        # of files is an exact multiple of batch_size.
        if counter + batch_size > len(images_list):
            counter = 0
            random.shuffle(images_list)
        yield templates_pack, masks_pack
def _gather_channels(x, indexes):
    """Select the channels listed in ``indexes`` from a 4D tensor.

    The channel axis is moved to the front, the requested entries are
    gathered along it, and the original axis order is then restored. Which
    axis is the channel axis is decided by the Keras image data format.
    """
    backend = tf.keras.backend
    if backend.image_data_format() == 'channels_last':
        # (B, H, W, C) -> (C, B, H, W) and back.
        to_front, to_original = (3, 0, 1, 2), (1, 2, 3, 0)
    else:
        # (B, C, H, W) -> (C, B, H, W); swapping axes 0 and 1 is its own inverse.
        to_front = to_original = (1, 0, 2, 3)
    channels_leading = backend.permute_dimensions(x, to_front)
    selected = backend.gather(channels_leading, indexes)
    return backend.permute_dimensions(selected, to_original)
def get_reduce_axes(per_image):
    """Return the list of axes a score is summed over.

    The two spatial axes are always included (their positions depend on the
    Keras image data format); axis 0 is prepended unless ``per_image`` is
    truthy.
    """
    channels_last = tf.keras.backend.image_data_format() == 'channels_last'
    spatial_axes = [1, 2] if channels_last else [2, 3]
    if per_image:
        return spatial_axes
    return [0] + spatial_axes
def gather_channels(*xs, indexes=None):
    """Slice every given tensor along the channel axis by ``indexes``.

    With ``indexes=None`` the inputs are returned untouched (as the tuple of
    positional arguments). A single integer is treated as a one-element list.
    Otherwise a list with one sliced tensor per input is returned.
    """
    if indexes is None:
        return xs
    wanted = [indexes] if isinstance(indexes, int) else indexes
    sliced = []
    for tensor in xs:
        sliced.append(_gather_channels(tensor, indexes=wanted))
    return sliced
def round_if_needed(x, threshold):
    """Binarize ``x`` against ``threshold``, or return it unchanged.

    When ``threshold`` is ``None`` nothing is done. Otherwise the result is
    the ``x > threshold`` comparison cast to the Keras default float type.
    """
    if threshold is None:
        return x
    backend = tf.keras.backend
    above = backend.greater(x, threshold)
    return backend.cast(above, backend.floatx())
def average(x, per_image=False, class_weights=None):
    """Reduce a score tensor to a single mean value.

    If ``per_image`` is truthy the tensor is first averaged over axis 0.
    If ``class_weights`` is given the result is multiplied by it (ordinary
    broadcasting applies) before the final mean over all remaining elements.
    """
    backend = tf.keras.backend
    scores = backend.mean(x, axis=0) if per_image else x
    if class_weights is not None:
        scores = scores * class_weights
    return backend.mean(scores)
def jaccard_metric(gt_mask, pred_mask, class_weights=1., class_indexes=None, smooth=1e-5, per_image=False, threshold=None):
    r"""Compute the Jaccard (intersection over union) score of two masks.

    Args:
        gt_mask: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
        pred_mask: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
        class_weights: 1. or list of class weights, len(weights) = C
        class_indexes: optional integer or list of integers selecting the
            channels to consider; with ``None`` all channels are used.
        smooth: value added to numerator and denominator to avoid division by zero
        per_image: if ``True``, the score is computed per image and then
            averaged over the batch (B); otherwise it is computed over the
            whole batch at once
        threshold: value used to binarize predictions (``>`` comparison);
            with ``None`` predictions are used as they are

    Returns:
        The weighted mean IoU/Jaccard score (within [0, 1] when the masks lie
        in [0, 1] and the weights do not exceed 1).
    """
    backend = tf.keras.backend
    truth, prediction = gather_channels(gt_mask, pred_mask, indexes=class_indexes)
    prediction = round_if_needed(prediction, threshold)
    reduce_over = get_reduce_axes(per_image)

    # |A ∩ B| and |A| + |B|; the union is their difference.
    overlap = backend.sum(truth * prediction, axis=reduce_over)
    combined = backend.sum(truth + prediction, axis=reduce_over)
    iou = (overlap + smooth) / (combined - overlap + smooth)
    return average(iou, per_image, class_weights)
def jaccard_loss(gt_mask, pred_mask, class_weights=1., class_indexes=None, smooth=1e-5, per_image=False, threshold=None):
    """Return ``1 - jaccard_metric(...)`` computed with the same arguments."""
    score = jaccard_metric(
        gt_mask,
        pred_mask,
        class_weights=class_weights,
        class_indexes=class_indexes,
        smooth=smooth,
        per_image=per_image,
        threshold=threshold,
    )
    return 1 - score
def jaccard_loss_wraper(class_weights=1., class_indexes=None, smooth=1e-5, per_image=False, threshold=None):
    """Return a two-argument loss function with the given options bound.

    The returned callable takes ``(gt_mask, pred_mask)`` and forwards them,
    together with the options captured here, to ``jaccard_loss``.
    """
    options = dict(
        class_weights=class_weights,
        class_indexes=class_indexes,
        smooth=smooth,
        per_image=per_image,
        threshold=threshold,
    )

    def jaccard_loss_keras(gt_mask, pred_mask):
        return jaccard_loss(gt_mask, pred_mask, **options)

    return jaccard_loss_keras
def jaccard_metric_wraper(class_weights=1., class_indexes=None, smooth=1e-5, per_image=False, threshold=None):
    """Return a two-argument metric function with the given options bound.

    The returned callable takes ``(gt_mask, pred_mask)`` and forwards them,
    together with the options captured here, to ``jaccard_metric``.
    """
    options = dict(
        class_weights=class_weights,
        class_indexes=class_indexes,
        smooth=smooth,
        per_image=per_image,
        threshold=threshold,
    )

    def jaccard_metric_keras(gt_mask, pred_mask):
        return jaccard_metric(gt_mask, pred_mask, **options)

    return jaccard_metric_keras
Параметры модели:
# Side length (in pixels) of the square network input.
IMAGE_WIDTH = 768

# Training callbacks: stop early, reduce the learning rate on a plateau and
# checkpoint the best full model.
callbacks = [
    EarlyStopping(patience=5, verbose=1),
    ReduceLROnPlateau(factor=0.1, patience=3, min_lr=0.00001, verbose=1),
    ModelCheckpoint("best_model.h5", verbose=1, save_best_only=True, save_weights_only=False),
]

train_templates_path = "E:/Explorium/images/train/templates"
train_masks_path = "E:/Explorium/images/train/masks"
valid_templates_path = "E:/Explorium/images/valid/templates"
valid_masks_path = "E:/Explorium/images/valid/masks"

TRAIN_SET_SIZE = len(os.listdir(train_templates_path))
VALID_SET_SIZE = len(os.listdir(valid_templates_path))
BATCH_SIZE = 4
EPOCHS = 100
# Step counts must be whole numbers of batches, so use integer division
# rather than "/" (which yields a float).
STEPS_PER_EPOCH = TRAIN_SET_SIZE // BATCH_SIZE
VALIDATION_STEPS = VALID_SET_SIZE // BATCH_SIZE

train_generator = data_gen(train_templates_path, train_masks_path, IMAGE_WIDTH, batch_size=BATCH_SIZE)
val_generator = data_gen(valid_templates_path, valid_masks_path, IMAGE_WIDTH, batch_size=BATCH_SIZE)

# LOADING ARCHITECTURE AND COMPILING
with open('my_basic_unet.json', 'r') as json_file:
    loaded_model_json = json_file.read()
model = model_from_json(loaded_model_json)

weights = calc_weights(train_masks_path)
# The saved architecture is expected to have ONE output channel (the
# foreground class), so the Jaccard score has exactly one entry per reduction.
# Passing both weights ([background, foreground]) would broadcast that single
# score against a 2-element list and average the result, scaling the metric
# by mean(weights) and keeping the loss from ever approaching 0. Only the
# weight of the class the channel represents (foreground) is passed, so that
# len(class_weights) equals the number of channels.
foreground_weights = weights[1:]
loss_function = jaccard_loss_wraper(class_weights=foreground_weights)
metric = jaccard_metric_wraper(class_weights=foreground_weights)
model.compile(optimizer=Adam(lr=0.0001), loss=loss_function, metrics=[metric])

# TRAINING
print("VERSION CHECK:", tf.__version__, tf.test.is_built_with_cuda(), device_lib.list_local_devices(), sep="\n")
history = model.fit_generator(
    train_generator,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=val_generator,
    validation_steps=VALIDATION_STEPS,
    callbacks=callbacks,
)