Я обучил модель и сохранил её с помощью колбэка ModelCheckpoint. Теперь я загружаю её и запускаю предсказание через keras.Model.predict, но получаю ошибку «Matrix size-incompatible», показанную ниже.
Я проверил, имеют ли данные, для которых запускается предсказание, правильную форму, — и это действительно так.
Есть какие-нибудь предложения?
print("***Show shape")
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()
while True:
data = session.run(next_batch)
except tf.errors.OutOfRangeError:
print("***Load model and predict")
model = tf.keras.models.load_model(model_file)
predictions = model.predict(dataset) # Matrix size-incompatible error
<DatasetV1Adapter shapes: (?, 1, 64, ?), types: tf.float32>
***Show shape
(1, 1, 64, 169)
(1, 1, 64, 169)
(1, 1, 64, 169)
(1, 1, 64, 169)
(1, 1, 64, 169)
(1, 1, 64, 169)
(1, 1, 64, 169)
Model: "model"
Layer (type) Output Shape Param #
mels (InputLayer) [(32, 1, 64, 169)] 0
l1_conv (Conv2D) (32, 32, 62, 167) 288
l1_bn (BatchNormalization) (32, 32, 62, 167) 96
l1 (Activation) (32, 32, 62, 167) 0
l1_mp (MaxPooling2D) (32, 32, 30, 83) 0
l2_conv (Conv2D) (32, 32, 28, 81) 9216
l2_bn (BatchNormalization) (32, 32, 28, 81) 96
l2 (Activation) (32, 32, 28, 81) 0
l2_mp (MaxPooling2D) (32, 32, 13, 40) 0
l3_conv (Conv2D) (32, 32, 11, 38) 9216
l3_bn (BatchNormalization) (32, 32, 11, 38) 96
l3 (Activation) (32, 32, 11, 38) 0
l3_mp (MaxPooling2D) (32, 32, 5, 18) 0
flatten (Flatten) (32, 2880) 0
logits (Dense) (32, 100) 288100
dense (Dense) (32, 10) 1010
Total params: 308,118
Trainable params: 307,926
Non-trainable params: 192
***Load model and predict
InvalidArgumentError Traceback (most recent call last)
<ipython-input-14-edec6ee91517> in <module>
----> 1 predict('/home/jul/data/xenocanto/audio/wav_22050hz_MLR/XC164420.M.wav', '/home/jul/data/ingerop/subset_1572008350/features/actdet_config.json', '/home/jul/data/ingerop/subset_1572008350/features/featex_config.json', '/home/jul/data/ingerop/subset_1572008350/run_1572428779/models/model.05-0.92.h5')
~/dev/phaunos_ml/phaunos_ml/experiments/ingerop_prediction.py in predict(audio_filename, actdet_cfg_file, featex_cfg_file, model_file)
106 model.summary()
--> 108 predictions = model.predict(dataset)
110 return predictions
~/.miniconda3/envs/phaunos_ml/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)
1076 verbose=verbose,
1077 steps=steps,
-> 1078 callbacks=callbacks)
1080 def reset_metrics(self):
~/.miniconda3/envs/phaunos_ml/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_freq, mode, validation_in_fit, prepared_feed_values_from_dataset, steps_name, **kwargs)
272 # `ins` can be callable in tf.distribute.Strategy + eager case.
273 actual_inputs = ins() if callable(ins) else ins
--> 274 batch_outs = f(actual_inputs)
275 except errors.OutOfRangeError:
276 if is_dataset:
~/.miniconda3/envs/phaunos_ml/lib/python3.6/site-packages/tensorflow/python/keras/backend.py in __call__(self, inputs)
3291 fetched = self._callable_fn(*array_vals,
-> 3292 run_metadata=self.run_metadata)
3293 self._call_fetch_callbacks(fetched[-len(self._fetches):])
3294 output_structure = nest.pack_sequence_as(
~/.miniconda3/envs/phaunos_ml/lib/python3.6/site-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
1456 ret = tf_session.TF_SessionRunCallable(self._session._session,
1457 self._handle, args,
-> 1458 run_metadata_ptr)
1459 if run_metadata:
1460 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: Matrix size-incompatible: In[0]: [32,90], In[1]: [2880,100]
[[{{node logits_9/MatMul}}]]
(1) Invalid argument: Matrix size-incompatible: In[0]: [32,90], In[1]: [2880,100]
[[{{node logits_9/MatMul}}]]
0 successful operations.
0 derived errors ignored.