Проблема оценщика Tensorflow с наборами данных - PullRequest
0 голосов
/ 27 ноября 2018

У меня странная проблема с оценщиками TF, и я пытаюсь использовать tf.Dataset в своей функции ввода.

Во-первых, моя модель выглядит следующим образом:

    model = tf.estimator.DNNClassifier(
        feature_columns=my_feature_column,
        hidden_units=[hidden_layers, hidden_layers],
        n_classes=n_classes)

имой характерный столбик выглядит так:

    my_feature_column = [tf.feature_column.numeric_column(key='image', shape=[32, 32, 3])]

Теперь, если я тренируюсь так, все работает нормально, и тренировка проходит через пару секунд:

    model.train(
        input_fn=tf.estimator.inputs.numpy_input_fn(
            dict({'image':X_train}),
            y_train,
            shuffle=True),
        steps=nb_epoch)

Но когда я пытаюсьчтобы добавить tf.Datasets в функцию ввода, потребуется вечное выполнение:

def input_fn(features, labels, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices(({'image':features}, labels))
    return dataset.shuffle(1000).batch(batch_size).repeat()

model.train(
    input_fn=lambda:input_fn(X_train, y_train, batch_size),
    steps=nb_epoch)

Кто-нибудь может увидеть, что я делаю неправильно, пожалуйста?Это должно быть идентично, верно?

Спасибо, Пол

1 Ответ

0 голосов
/ 03 декабря 2018

Ваш набор данных повторяется бесконечно, и максимальное число итераций по умолчанию отсутствует, поэтому тензор потока не знает, когда остановиться.

Замените строку с надписью return dataset.shuffle(1000).batch(batch_size).repeat() на что-то вроде return dataset.shuffle(1000).batch(batch_size).repeat(10), которая будет тренироватьсяна 10 эпох, и все будет хорошо.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...