Я пытаюсь реализовать тензорные потоки, предоставляя API высокого уровня, в частности, базовый классификатор. Однако, пытаясь обучить модель, я получаю следующее
Ошибка:
NotFoundError (see above for traceback): Key baseline/bias not found in checkpoint
[[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
Код:
import tensorflow as tf
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
def digit_cross():
# Number of classes, one class for each of 10 digits.
num_classes = 10
digit = datasets.load_digits()
x = digit.data
y = digit.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.3, random_state=42)
y_train_index = np.arange(y_train.size)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": np.array(x_train)},
y=np.array(y_train),
num_epochs=None,
shuffle=False)
# Build BaselineClassifier
classifier = tf.estimator.BaselineClassifier(n_classes=num_classes,
model_dir="./checkpoints_tutorial17-1/")
# Fit model.
classifier.train(train_input_fn)
digit_cross()