Ошибка при реализации высокоуровневого API Tenorflow - PullRequest
0 голосов
/ 03 мая 2018

Я пытаюсь реализовать тензорные потоки, предоставляя API высокого уровня, в частности, базовый классификатор. Однако, пытаясь обучить модель, я получаю следующее

Ошибка:

NotFoundError (see above for traceback): Key baseline/bias not found in checkpoint
     [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

Код:

import tensorflow as tf
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

def digit_cross():
    # Number of classes, one class for each of 10 digits.
    num_classes = 10

    digit = datasets.load_digits()
    x = digit.data
    y = digit.target
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.3, random_state=42)
    y_train_index = np.arange(y_train.size)

    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": np.array(x_train)},
        y=np.array(y_train),
        num_epochs=None,
        shuffle=False)

    # Build BaselineClassifier
    classifier = tf.estimator.BaselineClassifier(n_classes=num_classes,
                                                 model_dir="./checkpoints_tutorial17-1/")

    # Fit model.
    classifier.train(train_input_fn)

digit_cross()

1 Ответ

0 голосов
/ 03 мая 2018

Похоже, что у вас есть контрольная точка в model_dir="./checkpoints_tutorial17-1/", которая принадлежит другой модели, а не BaselineClassifier. Если быть точным, у вас есть файл контрольной точки и файлы model.ckpt- * в этой папке.

Как описано в тензорном потоке:

  • model_dir: каталог для сохранения параметров модели, графика и т. Д. Это также можно использовать для загрузки контрольных точек из каталога в оценщик для продолжения обучения ранее сохраненной модели. Если PathLike объект, путь будет разрешен. Если None, model_dir в config будет использоваться, если установлен. Если оба установлены, они должны быть одинаковыми. Если оба значения отсутствуют, будет использоваться временный каталог.

Здесь BaselineClassifier сначала построит график, который использует baseline/bias. Затем он обнаруживает, что в model_dir есть предыдущая контрольная точка. Он попытается загрузить эту контрольную точку, и вы должны увидеть информацию (если вы сделали tf.logging.set_verbosity(tf.logging.INFO)), говорящую что-то вроде

"INFO:tensorflow:Restoring parameters from .../checkpoints_tutorial17-1\model.ckpt-..."

Поскольку эта контрольная точка в model_dir не из BaselineClassifier, она не будет иметь baseline/bias. BaselineClassifier не может найти его и, следовательно, выдаст ошибку.

...