Model.predict возвращает «Matrix-несовместимый размер» - PullRequest
0 голосов
/ 01 ноября 2019

Я обучил модель, сохраненную с помощью обратного вызова ModelCheckpoint. Я загружаю его и запускаю прогнозирование с помощью keras.Model.predict, но получаю ошибку «Matrix size-несовместимый», показанную ниже.

Я проверил, что данные, для которых требуется запустить прогноз, имеют правильную форму, и это делает.

Любое предложение?

Код


    print("***dataset:")
    print(dataset)

    print("***Show shape")
    iterator = dataset.make_one_shot_iterator()
    next_batch = iterator.get_next()
    try:
        while True:
            data = session.run(next_batch)
            print(data.shape)
    except tf.errors.OutOfRangeError:
        pass

    print("***Load model and predict")
    model = tf.keras.models.load_model(model_file)
    model.summary()
    predictions = model.predict(dataset) # Matrix size-incompatible error

Выход

***dataset:
<DatasetV1Adapter shapes: (?, 1, 64, ?), types: tf.float32>
***Show shape
(1, 1, 64, 169)
(1, 1, 64, 169)
(1, 1, 64, 169)
...
(1, 1, 64, 169)
(1, 1, 64, 169)
(1, 1, 64, 169)
(1, 1, 64, 169)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mels (InputLayer)            [(32, 1, 64, 169)]        0         
_________________________________________________________________
l1_conv (Conv2D)             (32, 32, 62, 167)         288       
_________________________________________________________________
l1_bn (BatchNormalization)   (32, 32, 62, 167)         96        
_________________________________________________________________
l1 (Activation)              (32, 32, 62, 167)         0         
_________________________________________________________________
l1_mp (MaxPooling2D)         (32, 32, 30, 83)          0         
_________________________________________________________________
l2_conv (Conv2D)             (32, 32, 28, 81)          9216      
_________________________________________________________________
l2_bn (BatchNormalization)   (32, 32, 28, 81)          96        
_________________________________________________________________
l2 (Activation)              (32, 32, 28, 81)          0         
_________________________________________________________________
l2_mp (MaxPooling2D)         (32, 32, 13, 40)          0         
_________________________________________________________________
l3_conv (Conv2D)             (32, 32, 11, 38)          9216      
_________________________________________________________________
l3_bn (BatchNormalization)   (32, 32, 11, 38)          96        
_________________________________________________________________
l3 (Activation)              (32, 32, 11, 38)          0         
_________________________________________________________________
l3_mp (MaxPooling2D)         (32, 32, 5, 18)           0         
_________________________________________________________________
flatten (Flatten)            (32, 2880)                0         
_________________________________________________________________
logits (Dense)               (32, 100)                 288100    
_________________________________________________________________
dense (Dense)                (32, 10)                  1010      
=================================================================
Total params: 308,118
Trainable params: 307,926
Non-trainable params: 192

***Load model and predict
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-14-edec6ee91517> in <module>
----> 1 predict('/home/jul/data/xenocanto/audio/wav_22050hz_MLR/XC164420.M.wav', '/home/jul/data/ingerop/subset_1572008350/features/actdet_config.json', '/home/jul/data/ingerop/subset_1572008350/features/featex_config.json', '/home/jul/data/ingerop/subset_1572008350/run_1572428779/models/model.05-0.92.h5')

~/dev/phaunos_ml/phaunos_ml/experiments/ingerop_prediction.py in predict(audio_filename, actdet_cfg_file, featex_cfg_file, model_file)
    106     model.summary()
    107 
--> 108     predictions = model.predict(dataset)
    109 
    110     return predictions

~/.miniconda3/envs/phaunos_ml/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)
   1076           verbose=verbose,
   1077           steps=steps,
-> 1078           callbacks=callbacks)
   1079 
   1080   def reset_metrics(self):

~/.miniconda3/envs/phaunos_ml/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_freq, mode, validation_in_fit, prepared_feed_values_from_dataset, steps_name, **kwargs)
    272           # `ins` can be callable in tf.distribute.Strategy + eager case.
    273           actual_inputs = ins() if callable(ins) else ins
--> 274           batch_outs = f(actual_inputs)
    275         except errors.OutOfRangeError:
    276           if is_dataset:

~/.miniconda3/envs/phaunos_ml/lib/python3.6/site-packages/tensorflow/python/keras/backend.py in __call__(self, inputs)
   3290 
   3291     fetched = self._callable_fn(*array_vals,
-> 3292                                 run_metadata=self.run_metadata)
   3293     self._call_fetch_callbacks(fetched[-len(self._fetches):])
   3294     output_structure = nest.pack_sequence_as(

~/.miniconda3/envs/phaunos_ml/lib/python3.6/site-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
   1456         ret = tf_session.TF_SessionRunCallable(self._session._session,
   1457                                                self._handle, args,
-> 1458                                                run_metadata_ptr)
   1459         if run_metadata:
   1460           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: Matrix size-incompatible: In[0]: [32,90], In[1]: [2880,100]
     [[{{node logits_9/MatMul}}]]
  (1) Invalid argument: Matrix size-incompatible: In[0]: [32,90], In[1]: [2880,100]
     [[{{node logits_9/MatMul}}]]
     [[dense_9/Sigmoid/_2867]]
0 successful operations.
0 derived errors ignored.
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...