I'm learning to use the TensorFlow Keras functional API. I'm building a neural network that takes two inputs and produces a binary output (0/1).
The code for the model:
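import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, Flatten, Input, Subtract
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Augmenting generator; every batch it yields is a pair (images, labels)
train_datagen = ImageDataGenerator(rescale=1./255, horizontal_flip=True, rotation_range=15, validation_split=0.2)
train_generator = train_datagen.flow_from_directory('lfw',
    target_size=(224, 224),
    class_mode='binary',
    shuffle=True,
    seed=24
)
# Two image inputs encoded by the same frozen VGG16 (no top, ImageNet weights)
input1 = Input(shape=(224, 224, 3), name='Input_1')
input2 = Input(shape=(224, 224, 3), name="Input_2")
vgg = VGG16(False, 'imagenet', input_shape=(224, 224, 3))
vgg.trainable = False
encoder1 = vgg(input1)
encoder2 = vgg(input2)
# Element-wise difference of the two encodings, flattened into one sigmoid unit
difference_layer = Subtract(name="Difference_Layer")([encoder1, encoder2])
h1 = Flatten(name='Flatten')(difference_layer)
output = Dense(1, activation='sigmoid', name="Output")(h1)
model = tf.keras.Model(inputs=[input1, input2], outputs=[output])
model.compile(optimizer='adam', loss=tf.keras.losses.MSE, metrics=['acc', 'mse', 'loss'])
model.fit(train_generator, epochs=10)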
I get this error:
AssertionError: in user code:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:505 train_function *
outputs = self.distribute_strategy.run(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:951 run **
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2290 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2649 _call_for_each_replica
return fn(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:465 train_step **
y_pred = self(x, training=True)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:927 __call__
outputs = call_fn(cast_inputs, *args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py:714 call
convert_kwargs_to_constants=base_layer_utils.call_context().saving)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py:894 _run_internal_graph
assert str(id(x)) in tensor_dict, 'Could not compute output ' + str(x)
AssertionError: Could not compute output Tensor("Output_4/Identity:0", shape=(None, 1), dtype=float32)
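One possible reading of the traceback, offered as an assumption rather than a confirmed diagnosis: the model declares two inputs (Input_1 and Input_2), while flow_from_directory yields a single image batch per step, so the graph cannot compute Output. The sketch below shows one way to turn two batch iterators into two-input batches. The helper name pair_batches and the "same label" target are hypothetical and do not come from the code above.

```python
import numpy as np


def pair_batches(gen_a, gen_b):
    """Zip two (images, labels) batch iterators into two-input batches.

    Yields ((x_a, x_b), same), where same[i] is 1.0 when the i-th images
    of the two batches carry the same label and 0.0 otherwise.
    Hypothetical helper: the pairing rule is an assumption for illustration.
    """
    for (x_a, y_a), (x_b, y_b) in zip(gen_a, gen_b):
        # Trim to the shorter batch, since final batches can differ in size.
        n = min(len(x_a), len(x_b))
        same = (np.asarray(y_a[:n]) == np.asarray(y_b[:n])).astype("float32")
        yield (x_a[:n], x_b[:n]), same
```

With the setup above, this would presumably be fed by two flow_from_directory iterators created with different seeds and a class_mode that keeps per-identity labels (for example 'sparse'), then passed as model.fit(pair_batches(gen_a, gen_b), steps_per_epoch=len(gen_a), epochs=10). Keras directory iterators loop indefinitely, so steps_per_epoch would be needed to mark where an epoch ends.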