Почему метод прогнозирования объекта модели Keras не допускает размер пакета 1? - PullRequest
0 голосов
/ 19 июня 2019

Я обучил тонко настроенную модель ELMo, используя Keras, которая будет предсказывать только с batch_size из 2. Вот пример кода:

model_input = np.repeat(np.array([str(user_input)]), 2)
model.predict(model_input, batch_size=2)

Этот код работает отлично. Тем не менее, если я запустите это:

model_input = np.array([str(user_input)])
model.predict(model_input, batch_size=1)

Я получаю эту ошибку:

Traceback (most recent call last):
  File "nlu/nlu_classifiers/elmo_scratch.py", line 67, in <module>
    main()
  File "nlu/nlu_classifiers/elmo_scratch.py", line 61, in main
    model.predict(model_input, batch_size=1)
  File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/keras/engine/training.py", line 1169, in predict
    steps=steps)
  File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/keras/engine/training_arrays.py", line 294, in predict_loop
    batch_outs = f(ins_batch)
  File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2715, in __call__
    return self._call(inputs)
  File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2675, in _call
    fetched = self._callable_fn(*array_vals)
  File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1439, in __call__
    run_metadata_ptr)
  File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 528, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: input must be a vector, got shape: []
         [[{{node lambda_1/module_apply_default/StringSplit}}]]

Почему это? И есть ли способ предсказать на одном примере без использования np.repeat? Это не большая проблема, потому что это в основном та же скорость, но это меня немного раздражало.

1 Ответ

1 голос
/ 19 июня 2019

np.repeat() упаковывает np.array([str(user_input)]) в массив, но вы не вызываете np.repeat(), когда ваш batch_size равен 1, поэтому model_input - это одномерный массив вместо двумерного массива. Попробуйте это:

model_input = np.array([[str(user_input)]])
model.predict(model_input, batch_size=1)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...