Как читать и записывать tfrecord файлы 2d массива - PullRequest
0 голосов
/ 19 октября 2019

Я хочу сделать 2d массив размером (n, 3) в tfrecord file и прочитать его.

Код, который я написал для создания tfrecord file, равен

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
example = tf.train.Example(
      features=tf.train.Features(
          feature={
              'arry_x':_float_feature(array[:,0]),
              'arry_y':_float_feature(array[:,1]),
              'arry_z':_float_feature(array[:,2])}
         )
      )

with tf.compat.v1.python_io.TFRecordWriter(file_name) as writer:
    writer.write(example.SerializeToString())

И я попытался прочитать файл с TFRecordReader

def get_tfrecord_feature():
    return{
        'arry_x': tf.compat.v1.io.FixedLenFeature([], tf.float32),
        'arry_y': tf.compat.v1.io.FixedLenFeature([], tf.float32),
        'arry_z': tf.compat.v1.io.FixedLenFeature([], tf.float32)
    }
filenames = [file_name, file_name2, ...]
file_name_queue = tf.train.string_input_producer(filenames)

reader = tf.TFRecordReader()
_, serialized_example = reader.read(file_name_queue)

data = tf.compat.v1.io.parse_single_example(serialized_example, features=get_tfrecord_feature())

x = data['arry_x']
y = data['arry_y']
z = data['arry_z']

x, y, z = tf.train.batch([x, y, z], batch_size=1)

И я использовал tf.Session для проверки кода

with tf.compat.v1.Session() as sess:
    print(sess.run(x))

Кодработает без ошибок, но сеанс не печатает никакого значения. Я думаю, что способ прочитать tfrecord file был неправильным. Кто-нибудь может мне помочь?

Ответы [ 2 ]

0 голосов
/ 19 октября 2019

Спасибо за совет donglinjy, я исправил свой код здесь

def get_tfrecord_feature():
    return{
        'arry_x': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
        'arry_y': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
        'arry_z': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32)
    }

и здесь.

with tf.compat.v1.Session() as sess:
    coord=tf.train.Coordinator()
    threads=tf.train.start_queue_runners(coord=coord)
    print(sess.run(x))

Теперь это работает.

0 голосов
/ 19 октября 2019

Я думаю, что вы должны добавить длину списка, которая в вашем случае равна array.shape [0] , как показано ниже, к определению функций при разборе записи tf.

def get_tfrecord_feature():
    return{
        'arry_x': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
        'arry_y': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32),
        'arry_z': tf.compat.v1.io.FixedLenFeature([array.shape[0]], tf.float32)
    }

Вы можете оставить форму как [], если для FixedLenFeature есть только один элемент. https://tensorflow.org/versions/r1.15/api_docs/python/tf/io/FixedLenFeature

...