Я использую наборы данных TensorFlow 2.0 для поддержки функции подгонки моей модели. Вот код:
def build_model(self):
self.g_Model = Sequential()
self.g_Model.add(Embedding(self.g_Max_features, output_dim=256))
self.g_Model.add(LSTM(128))
self.g_Model.add(Dropout(0.5))
self.g_Model.add(Dense(1, activation='sigmoid'))
self.g_Model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
def train_model(self, filenames):
lstm_feature_description = {
'X': tf.io.FixedLenFeature(CONFIG.g_keras_lstm_max_document_length, tf.float32),
'y': tf.io.FixedLenFeature((), tf.int64),
}
def _parse_lstm_function(example_proto):
return tf.io.parse_single_example(serialized=example_proto, features=lstm_feature_description)
self.build_model()
# Start Preparing The Data
raw_lstm_dataset = tf.data.TFRecordDataset(CONFIG.g_record_file_lstm)
parsed_lstm_dataset = raw_lstm_dataset.map(_parse_lstm_function)
parsed_lstm_dataset = parsed_lstm_dataset.shuffle(CONFIG.g_shuffle_s).batch(CONFIG.g_Batch_size)
self.g_Model.fit(parsed_lstm_dataset, epochs=2)
Но я получаю следующую ошибку:
Traceback (most recent call last):
File "keras_lstm_v2.py", line 79, in train_model
1/Unknown - 0s 0s/step self.g_Model.fit(parsed_lstm_dataset, epochs=2)
File "venv_tf_new\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 728, in fit
use_multiprocessing=use_multiprocessing)
File "venv_tf_new\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 324, in fit
total_epochs=epochs)
File "venv_tf_new\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 123, in run_one_epoch
batch_outs = execution_function(iterator)
File "venv_tf_new\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 86, in execution_function
distributed_function(input_fn))
File "venv_tf_new\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 457, in __call__
result = self._call(*args, **kwds)
File "venv_tf_new\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 503, in _call
self._initialize(args, kwds, add_initializers_to=initializer_map)
File "venv_tf_new\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 408, in _initialize
*args, **kwds))
File "venv_tf_new\lib\site-packages\tensorflow_core\python\eager\function.py", line 1848, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "venv_tf_new\lib\site-packages\tensorflow_core\python\eager\function.py", line 2150, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "venv_tf_new\lib\site-packages\tensorflow_core\python\eager\function.py", line 2041, in _create_graph_function
capture_by_value=self._capture_by_value),
File "venv_tf_new\lib\site-packages\tensorflow_core\python\framework\func_graph.py", line 915, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "venv_tf_new\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 358, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "venv_tf_new\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 66, in distributed_function
model, input_iterator, mode)
File "venv_tf_new\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 118, in _prepare_feed_values
inputs = [inputs[key] for key in model._feed_input_names]
File "venv_tf_new\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 118, in <listcomp>
inputs = [inputs[key] for key in model._feed_input_names]
KeyError: 'embedding_input'
Я видел этот поток , однако это не проясняет вещидля меня. Насколько я понял, есть проблема с загруженными данными, но в соответствии с документацией для наборов данных она должна работать из коробки, поэтому я не мог понять, как ее исправить.
Любая помощь приветствуется,Спасибо!