Я использую модель BERT для классификации бинарных предложений. Я хочу сделать вывод из моей обученной модели, которую я смог экспортировать, используя serving_input_receiver_fn()
, как показано ниже:
def serving_input_receiver_fn():
feature_spec = {
"input_ids" : tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
"input_mask" : tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
"segment_ids" : tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
"label_ids" : tf.FixedLenFeature([], tf.int64)
}
input_example_sentence = tf.placeholder(dtype=tf.string,
shape=[None],
name='Sentence')
print("Input example shape:",input_example_sentence.shape)
receiver_tensors = {'Sentence': input_example_sentence}
features = tf.parse_example(input_example_sentence, feature_spec)
return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
export_path = 'home/ec2-user/SageMaker/Final/new_bert_classification'
estimator._export_to_tpu = False
estimator.export_saved_model(export_dir_base = export_path,serving_input_receiver_fn = serving_input_receiver_fn)
ВЫХОД:
Введите пример формы :( ?,)
Мне удалось загрузить модель, используя predictor.from_saved_model(export_dir)
из класса предикторов в тензорном потоке, как показано ниже:
export_dir = 'home/ec2-user/SageMaker/Final/new_bert_classification/1581391407/'
from tensorflow.contrib import predictor
predict_fn = predictor.from_saved_model(export_dir)
Я также позаботился об изменении формы введите предложения для прогнозирования по форме, как требуется для функции выше, следующим образом:
pred_sentences = list(val['Sentence'])
pred_sentences_array = np.asarray(pred_sentences)
print("prediction sentence shape:",pred_sentences_array.shape)
ВЫХОД:
форма предложения прогнозирования: (24,)
Однако, когда я пытаюсь сделать вывод из функции predict_fn
, я сталкиваюсь с ошибкой несоответствия массива.
pred = predict_fn({'Sentence': [pred_sentences_array]})['output']
OUTPUT:
---------------------------------------------------------------------------
ValueError Traceback (most recent call
last) <ipython-input-172-77538f68e8f9> in <module>()
----> 1 pred = predict_fn({'Sentence': [pred_sentences_array]})['output']
~/anaconda3/envs/amazonei_tensorflow_p36/lib/python3.6/site-packages/tensorflow/contrib/predictor/predictor.py
in __call__(self, input_dict)
75 if value is not None:
76 feed_dict[self.feed_tensors[key]] = value
---> 77 return self._session.run(fetches=self.fetch_tensors, feed_dict=feed_dict)
~/anaconda3/envs/amazonei_tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py
in run(self, fetches, feed_dict, options, run_metadata)
927 try:
928 result = self._run(None, fetches, feed_dict, options_ptr,
--> 929 run_metadata_ptr)
930 if run_metadata:
931 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
~/anaconda3/envs/amazonei_tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py
in _run(self, handle, fetches, feed_dict, options, run_metadata)
1126 'which has shape %r' % 1127
> (np_val.shape, subfeed_t.name,
> -> 1128 str(subfeed_t.get_shape()))) 1129 if not self.graph.is_feedable(subfeed_t): 1130
raise ValueError('Tensor %s may not be fed.' % subfeed_t)
ValueError: Cannot feed value of shape (1, 24) for Tensor
'Sentence:0', which has shape '(?,)'
Эта ошибка может быть вызвана [] вокруг pred_sentences_array
в predict_fn({'Sentence': [pred_sentences_array]})
, но если я уберу эти скобки, произойдет следующая ошибка:
pred = predict_fn({'Sentence': pred_sentences_array})['output']
OUTPUT
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call
last)
~/anaconda3/envs/amazonei_tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py
in _do_call(self, fn, *args) 1333 try:
-> 1334 return fn(*args) 1335 except errors.OpError as e:
~/anaconda3/envs/amazonei_tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py
in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
1318 return self._call_tf_sessionrun(
-> 1319 options, feed_dict, fetch_list, target_list, run_metadata) 1320
~/anaconda3/envs/amazonei_tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py
in _call_tf_sessionrun(self, options, feed_dict, fetch_list,
target_list, run_metadata) 1406 self._session, options,
feed_dict, fetch_list, target_list,
-> 1407 run_metadata) 1408
InvalidArgumentError: Could not parse example input, value: '· Study
Completion information.Study Completion or in the event of early
withdrawal'
[[{{node ParseExample/ParseExample}}]]
During handling of the above exception, another exception occurred:
InvalidArgumentError Traceback (most recent call
last) <ipython-input-173-12847c81b368> in <module>()
----> 1 pred = predict_fn({'Sentence': pred_sentences_array})['output']
Я попытался выполнить поиск в стеке, но не смог найти возможного решения. Я мог пропустить их, но это очень маловероятно. Надеюсь, что объяснение является оптимальным для понимания моей проблемы.