Tensorflow: Tensor должен быть из того же графика, что и Tensor - PullRequest
0 голосов
/ 02 декабря 2018

Первые шаги в тензорном потоке, я пытаюсь обучить модель DNN для классификации изображений.

Мой текущий код:

folder_path = Path('cropped_images/cropped')
df['filename'] = df['tag_id'].map(lambda tag: str(folder_path / (tag + '.png')))

def database_input_fn():

    def parse_image(filename, label):
        image_decoded = tf.image.decode_png(tf.read_file(filename), channels=3)
        image_resized = tf.image.resize_images(image_decoded, [64, 64])
        label = label == 'large vehicle'
        return image_resized, label

    filenames = tf.constant(df['filename'])
    labels = tf.constant(df['general_class'])
    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.map(parse_image)
    dataset = dataset.shuffle()
    dataset = dataset.batch(32)
    dataset = dataset.repeat()
    return dataset

images_fc = tf.feature_column.numeric_column('image', shape=[64, 64, 3])

estimator = tf.estimator.DNNClassifier(feature_columns=[images_fc],
                                     hidden_units=[32, 32, 32, 32])
metrics = estimator.train(lambda : dataset, steps=10000)

Где df означает pandas.DataFrame, содержащийпути к изображениям и соответствующие им метки.Изображения хранятся на диске по указанному выше пути к папке.

Я получаю следующую ошибку:

ValueError: Tensor("IteratorV2:0", shape=(), dtype=resource) must be from the same graph as Tensor("BatchDatasetV2_4:0", shape=(), dtype=variant).

Чего мне не хватает?Почему не все строится на одном графике?

1 Ответ

0 голосов
/ 02 декабря 2018

Я думаю, что

metrics = estimator.train(lambda : dataset, steps=10000)

может быть проблемой.Если вы проверите аргументы для поезда оценки , input_fn создаст и вернет набор данных, что означает, что для оценщика и функции ввода создается новый график.В вашем случае вы уже создали этот график за пределами этой области.Возможно, изменение вашего кода на что-то вроде:

metrics = estimator.train(input_fn=database_input_fn, steps=10000) 

может решить эту проблему!

...