Я обучил тонко настроенную модель ELMo, используя Keras, которая будет предсказывать только с batch_size
из 2
. Вот пример кода:
model_input = np.repeat(np.array([str(user_input)]), 2)
model.predict(model_input, batch_size=2)
Этот код работает отлично. Тем не менее, если я запустите это:
model_input = np.array([str(user_input)])
model.predict(model_input, batch_size=1)
Я получаю эту ошибку:
Traceback (most recent call last):
File "nlu/nlu_classifiers/elmo_scratch.py", line 67, in <module>
main()
File "nlu/nlu_classifiers/elmo_scratch.py", line 61, in main
model.predict(model_input, batch_size=1)
File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/keras/engine/training.py", line 1169, in predict
steps=steps)
File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/keras/engine/training_arrays.py", line 294, in predict_loop
batch_outs = f(ins_batch)
File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2715, in __call__
return self._call(inputs)
File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2675, in _call
fetched = self._callable_fn(*array_vals)
File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1439, in __call__
run_metadata_ptr)
File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 528, in __exit__
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: input must be a vector, got shape: []
[[{{node lambda_1/module_apply_default/StringSplit}}]]
Почему это? И есть ли способ предсказать на одном примере без использования np.repeat
? Это не большая проблема, потому что это в основном та же скорость, но это меня немного раздражало.