Несоответствие формы ввода при выводе из сохраненной и загруженной модели BERT - PullRequest
0 голосов
/ 11 февраля 2020

Я использую модель BERT для классификации бинарных предложений. Я хочу сделать вывод из моей обученной модели, которую я смог экспортировать, используя serving_input_receiver_fn(), как показано ниже:

def serving_input_receiver_fn():
  feature_spec = {
      "input_ids" : tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
      "input_mask" : tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
      "segment_ids" : tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
      "label_ids" :  tf.FixedLenFeature([], tf.int64)
  }
  input_example_sentence = tf.placeholder(dtype=tf.string,
                                         shape=[None],
                                         name='Sentence')
  print("Input example shape:",input_example_sentence.shape)
  receiver_tensors = {'Sentence': input_example_sentence}
  features = tf.parse_example(input_example_sentence, feature_spec)
  return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

export_path = 'home/ec2-user/SageMaker/Final/new_bert_classification'
estimator._export_to_tpu = False 
estimator.export_saved_model(export_dir_base = export_path,serving_input_receiver_fn = serving_input_receiver_fn)

ВЫХОД:

Введите пример формы :( ?,)

Мне удалось загрузить модель, используя predictor.from_saved_model(export_dir) из класса предикторов в тензорном потоке, как показано ниже:

export_dir = 'home/ec2-user/SageMaker/Final/new_bert_classification/1581391407/'
from tensorflow.contrib import predictor
predict_fn = predictor.from_saved_model(export_dir)

Я также позаботился об изменении формы введите предложения для прогнозирования по форме, как требуется для функции выше, следующим образом:

pred_sentences = list(val['Sentence'])
pred_sentences_array = np.asarray(pred_sentences)
print("prediction sentence shape:",pred_sentences_array.shape)

ВЫХОД:

форма предложения прогнозирования: (24,)

Однако, когда я пытаюсь сделать вывод из функции predict_fn, я сталкиваюсь с ошибкой несоответствия массива.

pred = predict_fn({'Sentence': [pred_sentences_array]})['output']

OUTPUT:

---------------------------------------------------------------------------
 ValueError                                Traceback (most recent call
last) <ipython-input-172-77538f68e8f9> in <module>()
----> 1 pred = predict_fn({'Sentence': [pred_sentences_array]})['output']

~/anaconda3/envs/amazonei_tensorflow_p36/lib/python3.6/site-packages/tensorflow/contrib/predictor/predictor.py
in __call__(self, input_dict)
      75       if value is not None:
      76         feed_dict[self.feed_tensors[key]] = value
 ---> 77     return self._session.run(fetches=self.fetch_tensors, feed_dict=feed_dict)

 ~/anaconda3/envs/amazonei_tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py
 in run(self, fetches, feed_dict, options, run_metadata)
     927     try:
     928       result = self._run(None, fetches, feed_dict, options_ptr,
 --> 929                          run_metadata_ptr)
     930       if run_metadata:
     931         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

 ~/anaconda3/envs/amazonei_tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py
 in _run(self, handle, fetches, feed_dict, options, run_metadata)   
 1126                              'which has shape %r' %    1127      
> (np_val.shape, subfeed_t.name,
> -> 1128                               str(subfeed_t.get_shape())))    1129           if not self.graph.is_feedable(subfeed_t):    1130      
 raise ValueError('Tensor %s may not be fed.' % subfeed_t)

 ValueError: Cannot feed value of shape (1, 24) for Tensor
 'Sentence:0', which has shape '(?,)'

Эта ошибка может быть вызвана [] вокруг pred_sentences_array в predict_fn({'Sentence': [pred_sentences_array]}), но если я уберу эти скобки, произойдет следующая ошибка:

pred = predict_fn({'Sentence': pred_sentences_array})['output'] 

OUTPUT

 ---------------------------------------------------------------------------
 InvalidArgumentError                      Traceback (most recent call
 last)
 ~/anaconda3/envs/amazonei_tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py
 in _do_call(self, fn, *args)    1333     try:
 -> 1334       return fn(*args)    1335     except errors.OpError as e:

 ~/anaconda3/envs/amazonei_tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py
 in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata) 
 1318       return self._call_tf_sessionrun(
 -> 1319           options, feed_dict, fetch_list, target_list, run_metadata)    1320 

 ~/anaconda3/envs/amazonei_tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py
 in _call_tf_sessionrun(self, options, feed_dict, fetch_list,
 target_list, run_metadata)    1406         self._session, options,
 feed_dict, fetch_list, target_list,
 -> 1407         run_metadata)    1408 

 InvalidArgumentError: Could not parse example input, value: '· Study
 Completion information.Study Completion or in the event of early
 withdrawal'    
 [[{{node ParseExample/ParseExample}}]]

 During handling of the above exception, another exception occurred:

 InvalidArgumentError                      Traceback (most recent call
 last) <ipython-input-173-12847c81b368> in <module>()
 ----> 1 pred = predict_fn({'Sentence': pred_sentences_array})['output']

Я попытался выполнить поиск в стеке, но не смог найти возможного решения. Я мог пропустить их, но это очень маловероятно. Надеюсь, что объяснение является оптимальным для понимания моей проблемы.

...