У меня странная проблема с оценщиками TF, и я пытаюсь использовать tf.Dataset в своей функции ввода.
Во-первых, моя модель выглядит следующим образом:
model = tf.estimator.DNNClassifier(
feature_columns=my_feature_column,
hidden_units=[hidden_layers, hidden_layers],
n_classes=n_classes)
имой характерный столбик выглядит так:
my_feature_column = [tf.feature_column.numeric_column(key='image', shape=[32, 32, 3])]
Теперь, если я тренируюсь так, все работает нормально, и тренировка проходит через пару секунд:
model.train(
input_fn=tf.estimator.inputs.numpy_input_fn(
dict({'image':X_train}),
y_train,
shuffle=True),
steps=nb_epoch)
Но когда я пытаюсьчтобы добавить tf.Datasets в функцию ввода, потребуется вечное выполнение:
def input_fn(features, labels, batch_size):
dataset = tf.data.Dataset.from_tensor_slices(({'image':features}, labels))
return dataset.shuffle(1000).batch(batch_size).repeat()
model.train(
input_fn=lambda:input_fn(X_train, y_train, batch_size),
steps=nb_epoch)
Кто-нибудь может увидеть, что я делаю неправильно, пожалуйста?Это должно быть идентично, верно?
Спасибо, Пол