Чтение из файлов .tfrecord с использованием tf.data.Dataset - PullRequest
0 голосов
/ 08 сентября 2018

Я хочу прочитать набор данных, сгенерированный этим кодом с API tf.data.Dataset. Репо показывает, что оно написано так:

def image_to_tfexample(image_data, image_format, height, width, class_id):
  return tf.train.Example(features=tf.train.Features(feature={
      'image/encoded': bytes_feature(image_data),
      'image/format': bytes_feature(image_format),
      'image/class/label': int64_feature(class_id),
      'image/height': int64_feature(height),
      'image/width': int64_feature(width),
  }))

с (encoded byte-string, b'png', 32, 32, label) в качестве параметров.

Итак, чтобы прочитать файл .tfrecord, формат данных должен быть:

example_fmt = {
    'image/encoded': tf.FixedLenFeature((), tf.string, ""),
    'image/format': tf.FixedLenFeature((), tf.string, ""),
    'image/class/label': tf.FixedLenFeature((), tf.int64, -1),
    'image/height': tf.FixedLenFeature((), tf.int64, -1),
    'image/width': tf.FixedLenFeature((), tf.int64, -1)
}
parsed = tf.parse_single_example(example, example_fmt)
image = tf.decode_raw(parsed['image/encoded'], out_type=tf.uint8)

Но это не работает. После считывания набор данных становится пустым, и генерация итератора с ним поднимает OutOfRangeError: End of sequence.

Короткий сценарий Python для воспроизведения можно найти здесь . Я изо всех сил пытаюсь найти точную документацию или примеры для этой проблемы.

Ответы [ 2 ]

0 голосов
/ 08 сентября 2018

Я все еще изучаю использование TensorFlow и tfrecordfile, поэтому я не владею этими вещами, но я нашел это руководство , которое было полезно в моем случае и могло бы быть полезным и для вас.

0 голосов
/ 08 сентября 2018

Я не могу проверить ваш код, потому что у меня нет файла train.tfrecords. Этот код создает пустой набор данных?

dataset = tf.data.TFRecordDataset('train.tfrecords')
dataset = dataset.map(parse_fn)
itr = dataset.make_one_shot_iterator()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(itr.get_next()))
        except tf.errors.OutOfRangeError:
            break

Если это дает вам ошибку, пожалуйста, дайте мне знать, какая строка ее выдаёт.

...