Я пытаюсь использовать готовый оценщик tf.estimator.DNNClassifier
для использования в наборе данных MNIST. Я загружаю набор данных из tensorflow_dataset
.
Я выполняю следующие четыре шага: сначала строим конвейер набора данных и определяем функцию ввода:
## Step 1
mnist, info = tfds.load('mnist', with_info=True)
ds_train_orig, ds_test = mnist['train'], mnist['test']
def train_input_fn(dataset, batch_size):
dataset = dataset.map(lambda x:({'image-pixels':tf.reshape(x['image'], (-1,))},
x['label']))
return dataset.shuffle(1000).repeat().batch(batch_size)
Затем, на шаге 2, я определяю столбец объектов с помощью одной клавиши и форму 784:
## Step 2:
image_feature_column = tf.feature_column.numeric_column(key='image-pixels',
shape=(28*28))
image_feature_column
NumericColumn(key='image-pixels', shape=(784,), default_value=None, dtype=tf.float32, normalizer_fn=None)
Шаг 3, я описал оценку следующим образом:
## Step 3:
dnn_classifier = tf.estimator.DNNClassifier(
feature_columns=image_feature_column,
hidden_units=[16, 16],
n_classes=10)
И, наконец, шаг 4, используя оценщик, вызвав метод .train()
:
## Step 4:
dnn_classifier.train(
input_fn=lambda:train_input_fn(ds_train_orig, batch_size=32),
#lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
steps=20)
Но это повторяется в следующей ошибке. Похоже, проблема возникла из набора данных.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-21-95736cd65e45> in <module>
2 dnn_classifier.train(
3 input_fn=lambda: train_input_fn(ds_train_orig, batch_size=32),
----> 4 steps=20)
~/anaconda3/envs/tf2.0-beta/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, ctx, accept_symbolic_tensors, accept_composite_tensors)
1183 graph = get_default_graph()
1184 if not graph.building_function:
-> 1185 raise RuntimeError("Attempting to capture an EagerTensor without "
1186 "building a function.")
1187 return graph.capture(value, name=name)
RuntimeError: Attempting to capture an EagerTensor without building a function.