Несколько функций потерь для нескольких входов в Керасе - PullRequest
0 голосов
/ 27 июня 2019

В моей сети 2 входа, одно трехмерное изображение в градациях серого и одно двумерное цветное изображение.В конце у меня есть один 2D-вывод, такой же формы, как и входное изображение.

Моя проблема: я хочу объединить 2 функции потерь, по одной для каждой входной ветви сети.Одна функция потерь должна сравнивать прогноз сети с двухмерной истинностью серой шкалы, а другая - с двумерным вводом цвета.Эти основополагающие факты хранятся в отдельных файлах.

Я не понимаю, как работает y_true и как я могу сказать, чтобы он смотрел на мои файлы основополагающих правд.Я думаю, что моя проблема в основном из-за синтаксиса, и после просмотра других постов на SO я просто запутался.

Вот код, который я придумал.Конечно, это не работает, но оно должно дать представление о том, к чему я стремлюсь.

def custom_loss(groundtruth_grayscale, groundtruth_colour):
    def loss(y_true, y_pred):
        loss_grayscale = ssim(y_pred, groundtruth_grayscale)
        loss_colour = ssim(y_pred, groundtruth_colour)
        ssim_loss = loss_grayscale + loss_colour

        l1_loss_grayscale = l1(y_pred, groundtruth_grayscale)
        l1_loss_colour = l1(y_pred, groundtruth_colour)
        l1_loss = l1_loss_grayscale + l1_loss_colour

        return ssim_loss + l1_loss
    return loss

# images_groundtruth_grayscale is a variable containing all groundtruth_grayscale images
model_combined.optimizer = tf.keras.optimizers.Adam(learning_rate = 0.002).minimize(custom_loss(images_groundtruth_grayscale, images_groundtruth_colour), var_list = model_combined.trainable_variables)

Если это поможет, краткая информация о модели:

model_combined.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, None, None,  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, None, None, 1 9728        input_1[0][0]                    
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, None, None,  0                                            
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, None, None, 1 512         conv2d[0][0]                     
__________________________________________________________________________________________________
conv3d (Conv3D)                 (None, None, None, N 16128       input_2[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, None, None, 6 204864      batch_normalization[0][0]        
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, None, None, N 512         conv3d[0][0]                     
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, None, None, 6 256         conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv3d_1 (Conv3D)               (None, None, None, N 1024064     batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, None, None, 3 18464       batch_normalization_1[0][0]      
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, None, None, N 256         conv3d_1[0][0]                   
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, None, None, 3 128         conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv3d_2 (Conv3D)               (None, None, None, N 55328       batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, None, None, 1 4624        batch_normalization_2[0][0]      
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, None, None, N 128         conv3d_2[0][0]                   
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, None, None, 1 64          conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv3d_3 (Conv3D)               (None, None, None, N 13840       batch_normalization_6[0][0]      
__________________________________________________________________________________________________
tf_op_layer_ExpandDims (TensorF [(None, None, None,  0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, None, None, N 64          conv3d_3[0][0]                   
__________________________________________________________________________________________________
tf_op_layer_add (TensorFlowOpLa [(None, None, None,  0           tf_op_layer_ExpandDims[0][0]     
                                                                 batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv3d_4 (Conv3D)               (None, None, None, N 272         tf_op_layer_add[0][0]            
==================================================================================================
Total params: 1,349,232
Trainable params: 1,348,272
Non-trainable params: 960
__________________________________________________________________________________________________
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...