My network has 2 inputs, one 3D grayscale image and one 2D colour image. At the end I have a single 2D output of the same shape as the input image.
My problem: I want to combine 2 loss functions, one for each input branch of the network. One loss function should compare the network's prediction against the 2D grayscale ground truth, the other against the 2D colour ground truth. These ground truths are stored in separate files.
I don't understand how y_true works and how I can tell it to look at my ground truth files. I think my problem is mostly one of syntax, and after reading other posts on SO I'm only more confused.
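As far as I understand, y_true is simply whatever I pass as the target y to model.fit, so with a single ground truth the standard pattern would be something like this (the array names are just placeholders for my data):

model_combined.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.002),
                       loss='mae')
model_combined.fit(x=[images_input_grayscale_3d, images_input_colour_2d],
                   y=images_groundtruth_grayscale,
                   batch_size=4, epochs=10)

But that way I can only hand over one ground truth array, and I don't see how to get the second one into the loss.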
Here is the code I came up with. Of course it doesn't work, but it should give an idea of what I'm aiming for.
def custom_loss(groundtruth_grayscale, groundtruth_colour):
    def loss(y_true, y_pred):
        # SSIM between the prediction and each of the two ground truths
        loss_grayscale = ssim(y_pred, groundtruth_grayscale)
        loss_colour = ssim(y_pred, groundtruth_colour)
        ssim_loss = loss_grayscale + loss_colour
        # L1 distance between the prediction and each of the two ground truths
        l1_loss_grayscale = l1(y_pred, groundtruth_grayscale)
        l1_loss_colour = l1(y_pred, groundtruth_colour)
        l1_loss = l1_loss_grayscale + l1_loss_colour
        return ssim_loss + l1_loss
    return loss

# images_groundtruth_grayscale is a variable containing all groundtruth_grayscale images
model_combined.optimizer = tf.keras.optimizers.Adam(learning_rate=0.002).minimize(
    custom_loss(images_groundtruth_grayscale, images_groundtruth_colour),
    var_list=model_combined.trainable_variables)
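The only workaround I have come up with so far is to stack both ground truths into a single target array and split them again inside the loss. This is only a sketch and assumes the grayscale ground truth has 1 channel and the colour ground truth has 3, so they can be concatenated along the last axis (np is numpy, ssim and l1 are the same helpers as above):

def combined_loss(y_true, y_pred):
    # y_true is assumed to contain both ground truths concatenated along the channel axis
    gt_grayscale = y_true[..., :1]   # first channel: grayscale ground truth
    gt_colour = y_true[..., 1:]      # remaining channels: colour ground truth
    ssim_loss = ssim(y_pred, gt_grayscale) + ssim(y_pred, gt_colour)
    l1_loss = l1(y_pred, gt_grayscale) + l1(y_pred, gt_colour)
    return ssim_loss + l1_loss

# one combined target array that could then be passed as y to model.fit
y_combined = np.concatenate([images_groundtruth_grayscale, images_groundtruth_colour], axis=-1)

But I'm not sure whether this is the right way to do it, or whether there is a proper way to give each branch its own loss and ground truth.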
In case it helps, here is a summary of the model:
model_combined.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, None, None, 0
__________________________________________________________________________________________________
conv2d (Conv2D) (None, None, None, 1 9728 input_1[0][0]
__________________________________________________________________________________________________
input_2 (InputLayer) [(None, None, None, 0
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, None, None, 1 512 conv2d[0][0]
__________________________________________________________________________________________________
conv3d (Conv3D) (None, None, None, N 16128 input_2[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, None, None, 6 204864 batch_normalization[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, None, None, N 512 conv3d[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, None, None, 6 256 conv2d_1[0][0]
__________________________________________________________________________________________________
conv3d_1 (Conv3D) (None, None, None, N 1024064 batch_normalization_4[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, None, None, 3 18464 batch_normalization_1[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, None, None, N 256 conv3d_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, None, None, 3 128 conv2d_2[0][0]
__________________________________________________________________________________________________
conv3d_2 (Conv3D) (None, None, None, N 55328 batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, None, None, 1 4624 batch_normalization_2[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, None, None, N 128 conv3d_2[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, None, None, 1 64 conv2d_3[0][0]
__________________________________________________________________________________________________
conv3d_3 (Conv3D) (None, None, None, N 13840 batch_normalization_6[0][0]
__________________________________________________________________________________________________
tf_op_layer_ExpandDims (TensorF [(None, None, None, 0 batch_normalization_3[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, None, None, N 64 conv3d_3[0][0]
__________________________________________________________________________________________________
tf_op_layer_add (TensorFlowOpLa [(None, None, None, 0 tf_op_layer_ExpandDims[0][0]
batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv3d_4 (Conv3D) (None, None, None, N 272 tf_op_layer_add[0][0]
==================================================================================================
Total params: 1,349,232
Trainable params: 1,348,272
Non-trainable params: 960
__________________________________________________________________________________________________