Каковы возможные причины низкого качества работы U-Net? - PullRequest
0 голосов
/ 02 апреля 2020

Я пытаюсь решить задачу бинарной сегментации с помощью U-Net в TensorFlow 2.0 (модуль Keras). Классы сильно несбалансированы, поэтому я использую веса классов (0,03 для фона и 1,0 для переднего плана). В обучающей выборке ~2500 образцов, в проверочной — ~250.

Образец данных (изображение и его маска):

[изображение] [маска]

Метрика — Intersection over Union (IoU, коэффициент Жаккара). Функция потерь — Jaccard loss (1 − IoU). Примерно через 10 эпох обучение прерывается ранней остановкой (early stopping). Значение функции потерь остаётся очень высоким, а метрики — очень низким. Я пробовал уменьшить скорость обучения, но это почти не помогло. [изображение] [изображение]

Когда я использую модель для предсказания, на выходе получается просто чёрный квадрат.

В чём проблема с моей моделью? Я что-то упустил? Это недостаток архитектуры? Неподходящая функция потерь или метрика? Проблема с весами классов? Буду признателен за любую помощь.

Архитектура сети:

from contextlib import redirect_stdout
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.layers import Input, BatchNormalization, Activation, Dropout
from tensorflow.python.keras.layers.convolutional import Conv2D, Conv2DTranspose
from tensorflow.python.keras.layers.pooling import MaxPooling2D
from tensorflow.python.keras.layers.merge import concatenate
import tensorflow as tf


# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all up front.
# NOTE(review): this is the TF1-style mechanism (a tf.compat.v1.Session). Under TF 2.x
# eager execution Keras does not run through this session object; the documented TF2
# equivalent is tf.config.experimental.set_memory_growth(gpu, True) per physical GPU.
# Confirm this setting actually takes effect in the TensorFlow version in use.
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)
# Side length, in pixels, of the square network input (used as Input((W, W, 3)) below).
IMAGE_WIDTH = 768


def get_unet(input_image, n_filters, kernel_size, dropout=0.5):
    """Build a 4-level U-Net that ends in a single-channel sigmoid map.

    Each level applies two 'same'-padded ReLU convolutions, each followed by
    batch normalization. The encoder halves the resolution four times with
    2x2 max pooling while doubling the filter count (16x at the bottleneck).
    The decoder upsamples with stride-2 transposed convolutions and
    concatenates the matching encoder output (skip connection).

    Args:
        input_image: Keras ``Input`` tensor. The first convolution is built with
            ``data_format="channels_last"``. Its spatial size should be divisible
            by 16 so the four poolings and four stride-2 upsamplings produce
            matching shapes for the skip concatenations.
        n_filters: number of filters at the first (full-resolution) level.
        kernel_size: side of the square convolution kernels.
        dropout: rate applied after every pooling step (halved after the first)
            and after every skip concatenation.

    Returns:
        A Keras ``Model`` from ``input_image`` to a 1-channel sigmoid output.
    """
    kernel = (kernel_size, kernel_size)

    def double_conv(tensor, filters, **first_conv_kwargs):
        # conv -> BN -> conv -> BN; extra kwargs apply only to the first conv.
        for conv_kwargs in (first_conv_kwargs, {}):
            tensor = Conv2D(filters=filters, kernel_size=kernel, activation='relu',
                            kernel_initializer="he_normal", padding="same", **conv_kwargs)(tensor)
            tensor = BatchNormalization()(tensor)
        return tensor

    # Encoder: keep each level's output for the decoder's skip connections.
    skips = []
    x = input_image
    for depth, rate in enumerate((dropout * 0.5, dropout, dropout, dropout)):
        extra = {"data_format": "channels_last"} if depth == 0 else {}
        x = double_conv(x, n_filters * 2 ** depth, **extra)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = Dropout(rate)(x)

    # Bottleneck.
    x = double_conv(x, n_filters * 16)

    # Decoder: walk the levels back up, deepest first.
    for depth in reversed(range(4)):
        filters = n_filters * 2 ** depth
        x = Conv2DTranspose(filters, kernel, strides=(2, 2), padding='same')(x)
        x = concatenate([x, skips[depth]])
        x = Dropout(dropout)(x)
        x = double_conv(x, filters)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    return Model(inputs=input_image, outputs=outputs)


# Build the U-Net for 3-channel square images of IMAGE_WIDTH x IMAGE_WIDTH pixels.
input_image = Input((IMAGE_WIDTH, IMAGE_WIDTH, 3), name='img')
model = get_unet(input_image, n_filters=16, kernel_size=3, dropout=0.05)

# model.summary() prints to stdout, so redirect it into a text file.
with open('binary_unet_summary.txt', 'w') as summary_file, redirect_stdout(summary_file):
    model.summary()

# Save only the architecture (no weights) as JSON.
model_json = model.to_json()
with open("my_basic_unet.json", "w") as json_file:
    json_file.write(model_json)

Генератор данных и другие функции:

def calc_weights(masks_folder):
    """Calculate class weights according to the class distribution in a dataset.

    The median number of foreground (non-zero) pixels per mask is compared
    with the median number of background pixels per mask. The foreground
    (assumed minority) class gets weight 1.0 and the background gets
    foreground / background, so the two classes are balanced.

    Args:
        masks_folder: directory that contains only mask images readable by cv2.

    Returns:
        ``[class_0_weight, class_1_weight]`` - background weight first.

    Raises:
        ValueError: if the folder holds no files, a file cannot be read as an
            image, or the median mask has no background pixels.
    """
    class_1_numbers = []
    pixel_numbers = []
    for name in os.listdir(masks_folder):
        path = os.path.join(masks_folder, name)
        mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if mask is None:  # cv2 returns None instead of raising for unreadable files
            raise ValueError(f"cv2.imread could not read mask {path!r}")
        # Scaling by 1/255 does not change which pixels are non-zero, so count directly.
        class_1_numbers.append(cv2.countNonZero(mask))
        pixel_numbers.append(mask.size)
    if not class_1_numbers:
        raise ValueError(f"No masks found in {masks_folder!r}")

    class_1_total = int(statistics.median(class_1_numbers))
    # Use the real mask size rather than assuming IMAGE_WIDTH x IMAGE_WIDTH masks.
    class_0_total = int(statistics.median(pixel_numbers)) - class_1_total
    if class_0_total <= 0:
        raise ValueError("Median mask has no background pixels; cannot balance classes")
    class_1_weight = 1.  # Maximum value to minority class
    class_0_weight = class_1_total / class_0_total  # Proportional value to majority class for classes balance
    return [class_0_weight, class_1_weight]


def data_gen(templates_folder, masks_folder, image_width, batch_size):
    """Generate (templates, masks) batches from the dataset indefinitely.

    Images and masks are paired by file name: every name listed in
    *templates_folder* is also read from *masks_folder*. The file list is
    shuffled once up front and again after every full pass over the data.
    Only full batches are produced.

    Args:
        templates_folder: directory with the input images (read with cv2's default flags).
        masks_folder: directory with grayscale masks named like the input images.
        image_width: side length every image and mask must have.
        batch_size: number of samples per batch (>= 1).

    Yields:
        ``(templates_pack, masks_pack)`` float arrays shaped
        ``(batch_size, image_width, image_width, 3)`` and
        ``(batch_size, image_width, image_width, 1)``, with pixel values divided by 255.

    Raises:
        ValueError: (when iteration starts) if ``batch_size < 1`` or the folder
            holds fewer than ``batch_size`` files.
        FileNotFoundError: if cv2 cannot read an image or mask.
    """
    def read_scaled(path, *flags):
        # cv2.imread returns None instead of raising, so check explicitly.
        image = cv2.imread(path, *flags)
        if image is None:
            raise FileNotFoundError(f"cv2.imread could not read {path!r}")
        return image / 255.

    images_list = os.listdir(templates_folder)
    if batch_size < 1 or len(images_list) < batch_size:
        raise ValueError(
            f"Need batch_size >= 1 and at least batch_size={batch_size} files "
            f"in {templates_folder!r}, found {len(images_list)}"
        )
    random.shuffle(images_list)
    counter = 0
    while True:
        templates_pack = np.zeros((batch_size, image_width, image_width, 3), dtype=float)
        masks_pack = np.zeros((batch_size, image_width, image_width, 1), dtype=float)
        for slot, name in enumerate(images_list[counter:counter + batch_size]):
            templates_pack[slot] = read_scaled(os.path.join(templates_folder, name))
            mask = read_scaled(os.path.join(masks_folder, name), cv2.IMREAD_GRAYSCALE)
            # (H, W) -> (H, W, 1) to fit the single-channel masks_pack slot.
            masks_pack[slot] = mask[..., np.newaxis]

        counter += batch_size
        # Start a new shuffled pass only when the NEXT full batch would run past
        # the end; ">=" here would silently skip the final full batch of each pass.
        if counter + batch_size > len(images_list):
            counter = 0
            random.shuffle(images_list)
        yield templates_pack, masks_pack


def _gather_channels(x, indexes):
    """Select the channels listed in *indexes* from a 4D tensor.

    The channel axis is moved to the front, ``gather`` picks the requested
    entries along it, and the original axis order is then restored.
    """
    backend = tf.keras.backend
    if backend.image_data_format() == 'channels_last':
        to_front, to_back = (3, 0, 1, 2), (1, 2, 3, 0)
    else:
        # Swapping axes 0 and 1 is its own inverse.
        to_front = to_back = (1, 0, 2, 3)
    moved = backend.permute_dimensions(x, to_front)
    picked = backend.gather(moved, indexes)
    return backend.permute_dimensions(picked, to_back)


def get_reduce_axes(per_image):
    """Return the axes to sum over when computing per-channel statistics.

    The two spatial axes are always included. The batch axis (0) is
    prepended unless statistics are wanted separately for each image.
    """
    channels_last = tf.keras.backend.image_data_format() == 'channels_last'
    spatial = [1, 2] if channels_last else [2, 3]
    return spatial if per_image else [0] + spatial


def gather_channels(*xs, indexes=None):
    """Slice every tensor in *xs* along the channel axis.

    Args:
        *xs: tensors to slice.
        indexes: an int or a list of ints naming the channels to keep; ``None``
            keeps everything.

    Returns:
        The inputs unchanged (as the original tuple) when ``indexes`` is None,
        otherwise a list of sliced tensors in the same order.
    """
    if indexes is None:
        return xs
    if isinstance(indexes, int):
        indexes = [indexes]
    return [_gather_channels(tensor, indexes=indexes) for tensor in xs]


def round_if_needed(x, threshold):
    """Binarize *x* as ``float(x > threshold)`` when a threshold is given.

    With ``threshold=None`` the input is returned untouched.
    """
    if threshold is None:
        return x
    backend = tf.keras.backend
    return backend.cast(backend.greater(x, threshold), backend.floatx())


def average(x, per_image=False, class_weights=None):
    """Collapse per-channel scores into one scalar, optionally class-weighted.

    Args:
        x: per-channel scores. When ``per_image`` is True, the leading axis
            indexes images and is averaged away first.
        per_image: whether ``x`` still has a leading per-image axis.
        class_weights: optional elementwise multiplier applied before the final
            mean. It is expected to match the channel count of ``x``. The lengths
            are NOT checked, so a mismatch broadcasts silently (a 1-channel score
            times a 2-element list yields 2 weighted copies of the same score).

    Returns:
        The plain mean of the (weighted) entries. The weights are not
        normalized by their sum.
    """
    backend = tf.keras.backend
    per_channel = backend.mean(x, axis=0) if per_image else x
    weighted = per_channel if class_weights is None else per_channel * class_weights
    return backend.mean(weighted)


def jaccard_metric(gt_mask, pred_mask, class_weights=1., class_indexes=None, smooth=1e-5, per_image=False, threshold=None):
    """Compute the (optionally class-weighted) IoU / Jaccard score.

    Args:
        gt_mask: ground-truth 4D tensor, (B, H, W, C) or (B, C, H, W).
        pred_mask: prediction 4D tensor with the same layout as ``gt_mask``.
        class_weights: 1. or a list of per-class weights with len(weights) = C.
        class_indexes: optional int or list of ints selecting the classes
            (channels) to score; ``None`` scores all of them.
        smooth: small constant added to numerator and denominator to avoid 0/0.
        per_image: if ``True``, the score is computed per image and then
            averaged over the batch (B); otherwise it is pooled over the whole batch.
        threshold: if not ``None``, predictions are binarized with ``> threshold``
            before scoring; otherwise soft predictions are used.

    Returns:
        Scalar IoU/Jaccard score. It lies in [0, 1] when every class weight is in [0, 1].
    """
    gt_mask, pred_mask = gather_channels(gt_mask, pred_mask, indexes=class_indexes)
    pred_mask = round_if_needed(pred_mask, threshold)
    reduce_axes = get_reduce_axes(per_image)

    sum_over = tf.keras.backend.sum
    overlap = sum_over(gt_mask * pred_mask, axis=reduce_axes)
    total = sum_over(gt_mask + pred_mask, axis=reduce_axes)
    union = total - overlap

    per_channel = (overlap + smooth) / (union + smooth)
    return average(per_channel, per_image, class_weights)


def jaccard_loss(gt_mask, pred_mask, class_weights=1., class_indexes=None, smooth=1e-5, per_image=False, threshold=None):
    """Return ``1 - jaccard_metric(...)`` computed with the same arguments."""
    score = jaccard_metric(
        gt_mask,
        pred_mask,
        class_weights=class_weights,
        class_indexes=class_indexes,
        smooth=smooth,
        per_image=per_image,
        threshold=threshold,
    )
    return 1 - score


def jaccard_loss_wraper(class_weights=1., class_indexes=None, smooth=1e-5, per_image=False, threshold=None):
    """Bind the Jaccard-loss options and return a two-argument ``(y_true, y_pred)`` loss.

    The inner function keeps the name ``jaccard_loss_keras``. Keras presumably
    reports losses by ``__name__`` (e.g. in saved models/history), so keep it stable.
    """
    options = dict(class_weights=class_weights, class_indexes=class_indexes,
                   smooth=smooth, per_image=per_image, threshold=threshold)

    def jaccard_loss_keras(gt_mask, pred_mask):
        return jaccard_loss(gt_mask, pred_mask, **options)

    return jaccard_loss_keras


def jaccard_metric_wraper(class_weights=1., class_indexes=None, smooth=1e-5, per_image=False, threshold=None):
    """Bind the Jaccard-metric options and return a two-argument ``(y_true, y_pred)`` metric.

    The inner function keeps the name ``jaccard_metric_keras``. Keras presumably
    labels metrics by ``__name__`` (e.g. history keys), so keep it stable.
    """
    options = dict(class_weights=class_weights, class_indexes=class_indexes,
                   smooth=smooth, per_image=per_image, threshold=threshold)

    def jaccard_metric_keras(gt_mask, pred_mask):
        return jaccard_metric(gt_mask, pred_mask, **options)

    return jaccard_metric_keras

Параметры модели:

# Training configuration: square input side length and Keras callbacks.
IMAGE_WIDTH = 768
callbacks = [
    EarlyStopping(patience=5, verbose=1),
    ReduceLROnPlateau(factor=0.1, patience=3, min_lr=0.00001, verbose=1),
    ModelCheckpoint("best_model.h5", verbose=1, save_best_only=True, save_weights_only=False)
]
train_templates_path = "E:/Explorium/images/train/templates"
train_masks_path = "E:/Explorium/images/train/masks"
valid_templates_path = "E:/Explorium/images/valid/templates"
valid_masks_path = "E:/Explorium/images/valid/masks"
TRAIN_SET_SIZE = len(os.listdir(train_templates_path))
VALID_SET_SIZE = len(os.listdir(valid_templates_path))
BATCH_SIZE = 4
EPOCHS = 100
# Keras expects a whole number of steps. The generators only ever yield full
# batches, so use floor division (but always run at least one step).
STEPS_PER_EPOCH = max(1, TRAIN_SET_SIZE // BATCH_SIZE)
VALIDATION_STEPS = max(1, VALID_SET_SIZE // BATCH_SIZE)
train_generator = data_gen(train_templates_path, train_masks_path, IMAGE_WIDTH, batch_size=BATCH_SIZE)
val_generator = data_gen(valid_templates_path, valid_masks_path, IMAGE_WIDTH, batch_size=BATCH_SIZE)

# LOADING ARCHITECTURE AND COMPILING
with open('my_basic_unet.json', 'r') as json_file:
    loaded_model_json = json_file.read()
model = model_from_json(loaded_model_json)

weights = calc_weights(train_masks_path)  # [background_weight, foreground_weight]
# The network ends in a single sigmoid channel, so the per-channel Jaccard score
# has shape (1,). Multiplying it by the 2-element [background, foreground] list
# broadcasts it to two entries whose mean is score * (w0 + w1) / 2. With weights
# like [0.03, 1.0], that caps the reported IoU near 0.515 and keeps the loss above
# ~0.485. Only the foreground weight corresponds to the single output channel.
foreground_weights = weights[1:]
loss_function = jaccard_loss_wraper(class_weights=foreground_weights)
metric = jaccard_metric_wraper(class_weights=foreground_weights)

# `learning_rate` is the documented name; `lr` is a deprecated alias.
model.compile(optimizer=Adam(learning_rate=0.0001), loss=loss_function, metrics=[metric])

# TRAINING
print("VERSION CHECK:", tf.__version__, tf.test.is_built_with_cuda(), device_lib.list_local_devices(), sep="\n")
history = model.fit_generator(train_generator, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, validation_data=val_generator, validation_steps=VALIDATION_STEPS, callbacks=callbacks)
...