API tf.data читает файлы TFRecord - PullRequest
       11

API tf.data читает файлы TFRecord

0 голосов
/ 26 октября 2018

Я пытаюсь использовать API tf.data для чтения файла TFRecord.

import tensorflow as tf
from PIL import Image
import numpy as np
import os

def train_input_fn():
    filenames = ["mytrain.tfrecords"]
    dataset = tf.data.TFRecordDataset(filenames)

    def parser(record):
        keys_to_features = {
            "image_data": tf.FixedLenFeature((), tf.string, default_value=""),
            "date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
            "label": tf.FixedLenFeature((), tf.int64,
                                        default_value=tf.zeros([], dtype=tf.int64)),
        }
        parsed = tf.parse_single_example(record, keys_to_features)

        image = tf.decode_jpeg(parsed["image_data"])
        image = tf.reshape(image, [128, 128, 3])
        label = tf.cast(parsed["label"], tf.int32)

        return {"image_data": image, "date_time": parsed["date_time"]}, label

    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(32)
    dataset = dataset.repeat(1)
    iterator = dataset.make_one_shot_iterator()

    features, labels = iterator.get_next()
    return features, labels

output = train_input_fn()

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord = coord)
    for i in range(230):
        image, label = sess.run(output)
        img = Image.fromarray(image, 'RGB')
        img.save(cwd+str(i) + '_''Label_'+str(l)+'.jpg')
        print(image, label)
    coord.request_stop()
    coord.join(threads)

Traceback (последний вызов был последним): файл "E: /Tensorflow/Wenshan_Cai_Nanoletters/tf_data.py", строка34, в output = train_input_fn () Файл "E: /Tensorflow/Wenshan_Cai_Nanoletters/tf_data.py", строка 25, в train_input_fn Ошибка типа: ожидаемый тип int64, вместо него получен тип 'str'.

1 Ответ

0 голосов
/ 26 октября 2018

Примечание TypeError: Expected int64, got '' of type 'str' instead из вашего журнала ошибок.У вас есть ошибка в вашем коде.

Ошибка

В следующей строке:

"date_time": tf.FixedLenFeature((), tf.int64, default_value=""),

Значение по умолчанию для переменной типа tf.int64задается в виде строки "".

Исправление

Итак, предположим, что ожидаемое значение по умолчанию равно 0, тогда вам следует изменить строку на:

"date_time": tf.FixedLenFeature((), tf.int64, default_value=0),

Надеюсь, это поможет.

...