Как декодировать список байтов, хранящихся в tfrecords, используя TensorFlow? - PullRequest
0 голосов
/ 05 ноября 2018

Я сохранил различное количество изображений в виде списка байтов:

img.append(trajectory_step['img'].tostring())
feature['img'] = _bytes_feature(img)
...
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())

Я также сохранил количество изображений для последующего декодирования (experiment_length).

Теперь я не могу декодировать изображения следующим образом:

features = {
'img': tf.VarLenFeature(tf.string),
...
}
parsed_features = tf.parse_single_example(example_proto, features)
img = tf.decode_raw(parsed_features['img'], out_type=tf.uint8)
img = tf.reshape(img, tf.stack([experiment_length, 120, 160, 3]))

Что приводит к следующей ошибке:

TypeError: Ожидаемая строка передана параметру 'bytes' операции 'DecodeRaw', получил вместо типа «SparseTensor».

Если я решу использовать tf.FixedLenFeature, я получаю следующую ошибку:

tenorflow.python.framework.errors_impl.InvalidArgumentError: Name: , Ключ: img, Индекс: 0. Количество байтовых значений! = Ожидается. Размер значения: 5, но форма вывода: []

Как правильно декодировать список байтов? И: tf.VarLenFeature в этом случае правильно или я должен использовать tf.FixedLenFeature?

Спасибо

...