Я (все еще) пытаюсь реализовать простую Unet сеть, используя Keras на бэкэнде Tensorflow 2.0.
Я смог обучить свою модель, используя 768x768 RGB-изображений. Вот примеры изображений, которые я использовал для обучения:
a) Исходное изображение
b) Маска
Вот моя модель:
import numpy as np
import os
import cv2
import random
from tensorflow.python.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.layers import Input, BatchNormalization, Activation, Dropout
from tensorflow.python.keras.layers.convolutional import Conv2D, Conv2DTranspose
from tensorflow.python.keras.layers.pooling import MaxPooling2D
from tensorflow.python.keras.layers.merge import concatenate
import tensorflow as tf
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)
def data_gen(templates_folder, masks_folder, image_width, batch_size):
counter = 0
images_list = os.listdir(templates_folder)
random.shuffle(images_list)
while True:
templates_pack = np.zeros((batch_size, image_width, image_width, 3)).astype('float')
masks_pack = np.zeros((batch_size, image_width, image_width, 1)).astype('float')
for i in range(counter, counter + batch_size):
template = cv2.imread(templates_folder + '/' + images_list[i]) / 255.
templates_pack[i - counter] = template
mask = cv2.imread(masks_folder + '/' + images_list[i], cv2.IMREAD_GRAYSCALE) / 255.
mask = mask.reshape(image_width, image_width, 1) # Add extra dimension for parity with template size [1536 * 1536 * 3]
masks_pack[i - counter] = mask
counter += batch_size
if counter + batch_size >= len(images_list):
counter = 0
random.shuffle(images_list)
yield templates_pack, masks_pack
def get_unet(input_image, n_filters, kernel_size, dropout=0.5):
conv_1 = Conv2D(filters=n_filters, kernel_size=(kernel_size, kernel_size), data_format="channels_last", activation='relu', kernel_initializer="he_normal", padding="same")(input_image)
conv_1 = BatchNormalization()(conv_1)
conv_2 = Conv2D(filters=n_filters, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_1)
conv_2 = BatchNormalization()(conv_2)
pool_1 = MaxPooling2D(pool_size=(2, 2))(conv_2)
pool_1 = Dropout(dropout * 0.5)(pool_1)
conv_3 = Conv2D(filters=n_filters * 2, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(pool_1)
conv_3 = BatchNormalization()(conv_3)
conv_4 = Conv2D(filters=n_filters * 2, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_3)
conv_4 = BatchNormalization()(conv_4)
pool_2 = MaxPooling2D(pool_size=(2, 2))(conv_4)
pool_2 = Dropout(dropout)(pool_2)
conv_5 = Conv2D(filters=n_filters * 4, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(pool_2)
conv_5 = BatchNormalization()(conv_5)
conv_6 = Conv2D(filters=n_filters * 4, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_5)
conv_6 = BatchNormalization()(conv_6)
pool_3 = MaxPooling2D(pool_size=(2, 2))(conv_6)
pool_3 = Dropout(dropout)(pool_3)
conv_7 = Conv2D(filters=n_filters * 8, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(pool_3)
conv_7 = BatchNormalization()(conv_7)
conv_8 = Conv2D(filters=n_filters * 8, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_7)
conv_8 = BatchNormalization()(conv_8)
pool_4 = MaxPooling2D(pool_size=(2, 2))(conv_8)
pool_4 = Dropout(dropout)(pool_4)
conv_9 = Conv2D(filters=n_filters * 16, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(pool_4)
conv_9 = BatchNormalization()(conv_9)
conv_10 = Conv2D(filters=n_filters * 16, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_9)
conv_10 = BatchNormalization()(conv_10)
upconv_1 = Conv2DTranspose(n_filters * 8, (kernel_size, kernel_size), strides=(2, 2), padding='same')(conv_10)
concat_1 = concatenate([upconv_1, conv_8])
concat_1 = Dropout(dropout)(concat_1)
conv_11 = Conv2D(filters=n_filters * 8, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(concat_1)
conv_11 = BatchNormalization()(conv_11)
conv_12 = Conv2D(filters=n_filters * 8, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_11)
conv_12 = BatchNormalization()(conv_12)
upconv_2 = Conv2DTranspose(n_filters * 4, (kernel_size, kernel_size), strides=(2, 2), padding='same')(conv_12)
concat_2 = concatenate([upconv_2, conv_6])
concat_2 = Dropout(dropout)(concat_2)
conv_13 = Conv2D(filters=n_filters * 4, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(concat_2)
conv_13 = BatchNormalization()(conv_13)
conv_14 = Conv2D(filters=n_filters * 4, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_13)
conv_14 = BatchNormalization()(conv_14)
upconv_3 = Conv2DTranspose(n_filters * 2, (kernel_size, kernel_size), strides=(2, 2), padding='same')(conv_14)
concat_3 = concatenate([upconv_3, conv_4])
concat_3 = Dropout(dropout)(concat_3)
conv_15 = Conv2D(filters=n_filters * 2, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(concat_3)
conv_15 = BatchNormalization()(conv_15)
conv_16 = Conv2D(filters=n_filters * 2, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_15)
conv_16 = BatchNormalization()(conv_16)
upconv_4 = Conv2DTranspose(n_filters * 1, (kernel_size, kernel_size), strides=(2, 2), padding='same')(conv_16)
concat_4 = concatenate([upconv_4, conv_2])
concat_4 = Dropout(dropout)(concat_4)
conv_17 = Conv2D(filters=n_filters * 1, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(concat_4)
conv_17 = BatchNormalization()(conv_17)
conv_18 = Conv2D(filters=n_filters * 1, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_17)
conv_18 = BatchNormalization()(conv_18)
conv_19 = Conv2D(1, (1, 1), activation='sigmoid')(conv_18)
model = Model(inputs=input_image, outputs=conv_19)
return model
callbacks = [EarlyStopping(patience=10, verbose=1),
ReduceLROnPlateau(factor=0.1, patience=3, min_lr=0.00001, verbose=1),
ModelCheckpoint("model-prototype.h5", verbose=1, save_best_only=True, save_weights_only=True)
]
train_templates_path = "E:/train/templates"
train_masks_path = "E:/train/masks"
valid_templates_path = "E:/valid/templates"
valid_masks_path = "E:/valid/masks"
TRAIN_SET_SIZE = len(os.listdir(train_templates_path))
VALID_SET_SIZE = len(os.listdir(valid_templates_path))
BATCH_SIZE = 1
EPOCHS = 100
STEPS_PER_EPOCH = TRAIN_SET_SIZE / BATCH_SIZE
VALIDATION_STEPS = VALID_SET_SIZE / BATCH_SIZE
IMAGE_WIDTH = 1536
train_generator = data_gen(train_templates_path, train_masks_path, IMAGE_WIDTH, batch_size = BATCH_SIZE)
val_generator = data_gen(valid_templates_path, valid_masks_path, IMAGE_WIDTH, batch_size = BATCH_SIZE)
input_image = Input((IMAGE_WIDTH, IMAGE_WIDTH, 3), name='img')
model = get_unet(input_image, n_filters=16, kernel_size = 3, dropout=0.05)
model.compile(optimizer=Adam(lr=0.001), loss="binary_crossentropy", metrics=["accuracy"])
results = model.fit_generator(train_generator, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, validation_data=val_generator, validation_steps=VALIDATION_STEPS, callbacks=callbacks)
test_image = cv2.imread("Test.jpg", 1)
prepared_image = np.expand_dims(test_image, axis=0).astype('float') # Adding extra dimension to fit model's input
prediction = model.predict(prepared_image)
ready_image = prediction[0, :, :, :] # Removing extra dimension to extract the image from the output
cv2.imwrite("Predicted.jpg", ready_image)
Результатом этого кода (Predicted.jpg) является черный квадрат 768x768 независимо от того, какое изображение я даю в качестве ввода. Есть идеи?