Я попытался запустить код модели keras:
def new_model(input_shape=(4, 128, 128, 128), n_base_filters=16, depth=5, dropout_rate=0.3,
n_segmentation_levels=3, n_labels=4, optimizer=Adam, initial_learning_rate=5e-4,
loss_function=weighted_dice_coefficient_loss, activation_name="sigmoid",metrics=dice_coefficient):
inputs = Input(input_shape)
split = Lambda( lambda x: tf.split(x,num_or_size_splits=4,axis=1))(inputs)
p_4 = tf.keras.layers.Conv3D(
name = "p_4",
filters = 1,
kernel_size = (1,1,1),
strides = (1,1,1),
padding = "valid",
dilation_rate = (1,1,1),
use_bias = True,
kernel_initializer = "glorot_uniform",
bias_initializer = "zeros",
)(split[0])
p_5 = tf.keras.layers.Conv3D(
name = "p_5",
filters = 1,
kernel_size = (1,1,1),
strides = (1,1,1),
padding = "valid",
dilation_rate = (1,1,1),
use_bias = True,
kernel_initializer = "glorot_uniform",
bias_initializer = "zeros",
)(split[1])
p_6 = tf.keras.layers.Conv3D(
name = "p_6",
filters = 1,
kernel_size = (1,1,1),
strides = (1,1,1),
padding = "valid",
dilation_rate = (1,1,1),
use_bias = True,
kernel_initializer = "glorot_uniform",
bias_initializer = "zeros",
)(split[2])
p_7 = tf.keras.layers.Conv3D(
name = "p_7",
filters = 1,
kernel_size = (1,1,1),
strides = (1,1,1),
padding = "valid",
dilation_rate = (1,1,1),
use_bias = True,
kernel_initializer = "glorot_uniform",
bias_initializer = "zeros",
)(split[3])
p_8 = tf.keras.layers.Activation(
name = "p_8",
activation = activation_name,
)(
tf.keras.layers.Concatenate(name='concat_p_4_p_5_p_6_p_7', axis=0)([
p_4,
p_5,
p_6,
p_7
])
)
output_layer = p_8
model = Model(inputs=inputs, outputs=output_layer)
if not isinstance(metrics, list):
metrics = [metrics]
# model.compile(optimizer=optimizer(lr=initial_learning_rate), loss=loss_function)
model.compile(optimizer=optimizer(lr=initial_learning_rate), loss=loss_function, metrics=metrics)
return model
Когда я запустил эту модель, я получил следующую ошибку. Я сомневаюсь, что проблема вызвана из-за лямбда-оператора.
Invalid ArgumentError: Incompatible shapes: [6291456] vs. [8388608]
[[node mul (defined at ../unet3d/metrics.py:9) ]] [Op:__inference_train_function_921]
Errors may have originated from an input operation.
Input Source operations connected to node mul:
Reshape (defined at /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:2700)
Function call stack:
train_function