Tensorflow Estimator InvalidArgumentError - PullRequest
       5

Tensorflow Estimator InvalidArgumentError

0 голосов
/ 08 января 2019

Я пытаюсь найти способ найти и исправить ошибку в моем коде TF. Фрагмент кода ниже успешно обучает модель, но генерирует следующую ошибку при вызове последней строки (model.evaluate (input_fn)):

InvalidArgumentError: Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
/var/folders/kx/y9syv3f91b1c6tzt3fgzc7jm0000gn/T/tmp_r6c94ni/model.ckpt-667.data-00000-of-00001; Invalid argument
     [[node save/RestoreV2 (defined at ../text_to_topic/train/nn/nn_tf.py:266)  = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

Caused by op 'save/RestoreV2', defined at:
  File "/Users/foo/miniconda3/envs/tt/lib/python3.6/runpy.py", line 193, in _run_module_as_main

Точно такой же код работает при использовании с набором данных MNIST, но не работает при использовании с моим собственным набором данных. Как я могу отладить это или что может быть причиной. Кажется, что графики не совпадают после того, как модель восстановлена ​​из контрольной точки, но я не уверен, как поступить, чтобы исправить это. Я пробовал с TF версии 1.11 и 1.13

model = tf.estimator.Estimator(get_nn_model_fn(num_classes))

# Define the input function for training
input_fn = tf.estimator.inputs.numpy_input_fn(
    x=X_train, y=y_train,
    batch_size=batch_size,
    num_epochs=None, shuffle=True)

# Train the Model
model.train(input_fn, steps=num_steps)

# Evaluate the Model
# Define the input function for evaluating
input_fn = tf.estimator.inputs.numpy_input_fn(
    x=X_test, y=y_test,
    batch_size=batch_size, shuffle=False)

# Use the Estimator 'evaluate' method
e = model.evaluate(input_fn) 

1 Ответ

0 голосов
/ 08 января 2019

Эта ошибка часто возникает при изменении какой-либо части графика, например, измените размер скрытых слоев или удалите / добавьте несколько слоев, и оценщик попытается загрузить ранее контрольные точки. У вас есть два варианта решения проблемы:

1) Изменить каталог модели (model_dir):

config = tf.estimator.RunConfig(model_dir='./NEW_PATH/', ) # new path
model_estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)

2) Удалить ранее сохраненные контрольные точки в каталоге модели (model_dir).


Вы уверены, что график не тронут?

Убедитесь, что новый набор данных имеет тот же Data-type, что и раньше. Если вы ранее загрузили числа с плавающей запятой для входных данных, в новом наборе данных они также должны быть числами с плавающей запятой.

...