Я столкнулся с ошибкой ниже при попытке обучить модель классификации нескольких классов (4 класса) для набора данных Image. Несмотря на то, что мой выходной тензор имеет форму 4, я сталкиваюсь с проблемой ниже. Пожалуйста, дайте мне знать, как решить эту проблему.
Epoch 1/10
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-30-01c6f78f4d4f> in <module>
4 epochs=epochs,
5 validation_data=val_data_gen,
----> 6 validation_steps=total_val // batch_size
7 )
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
1294 shuffle=shuffle,
1295 initial_epoch=initial_epoch,
-> 1296 steps_name='steps_per_epoch')
1297
1298 def evaluate_generator(self,
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_generator.py in model_iteration(model, data, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch, mode, batch_size, steps_name, **kwargs)
263
264 is_deferred = not model._is_compiled
--> 265 batch_outs = batch_function(*batch_data)
266 if not isinstance(batch_outs, list):
267 batch_outs = [batch_outs]
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics)
1015 self._update_sample_weight_modes(sample_weights=sample_weights)
1016 self._make_train_function()
-> 1017 outputs = self.train_function(ins) # pylint: disable=not-callable
1018
1019 if reset_metrics:
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/backend.py in __call__(self, inputs)
3474
3475 fetched = self._callable_fn(*array_vals,
-> 3476 run_metadata=self.run_metadata)
3477 self._call_fetch_callbacks(fetched[-len(self._fetches):])
3478 output_structure = nest.pack_sequence_as(
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py in __call__(self, *args, **kwargs)
1470 ret = tf_session.TF_SessionRunCallable(self._session._session,
1471 self._handle, args,
-> 1472 run_metadata_ptr)
1473 if run_metadata:
1474 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: Incompatible shapes: [4,3] vs. [4,4]
[[{{node loss_2/predictions_loss/logistic_loss/mul}}]]
[[loss_2/mul/_19047]]
(1) Invalid argument: Incompatible shapes: [4,3] vs. [4,4]
[[{{node loss_2/predictions_loss/logistic_loss/mul}}]]
0 successful operations.
0 derived errors ignored.
Размер моей партии 4 и ниже - последние несколько слоев моей модели.
conv5_block16_2_conv (Conv2D) (None, 16, 16, 32) 36864 conv5_block16_1_relu[0][0]
__________________________________________________________________________________________________
conv5_block16_concat (Concatena (None, 16, 16, 1024) 0 conv5_block15_concat[0][0]
conv5_block16_2_conv[0][0]
__________________________________________________________________________________________________
bn (BatchNormalization) (None, 16, 16, 1024) 4096 conv5_block16_concat[0][0]
__________________________________________________________________________________________________
relu (Activation) (None, 16, 16, 1024) 0 bn[0][0]
__________________________________________________________________________________________________
avg_pool (GlobalAveragePooling2 (None, 1024) 0 relu[0][0]
__________________________________________________________________________________________________
predictions (Dense) (None, 4) 4100 avg_pool[0][0]
==================================================================================================
Функция потерь
model.compile(optimizer='adam',
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))