Проблема с чтением и получением пакета с 2d массивом данных tfrecords - PullRequest
0 голосов
/ 21 апреля 2020

У меня проблемы с чтением Tfrecords файла при получении пакета.

Во-первых, это мой read_tfrecords.py файл.

import tensorflow as tf
import os
from glob import glob
import numpy as np


def serialize_example(batch, list1, list2):
    filename = "./train_set.tfrecords"
    writer = tf.io.TFRecordWriter(filename)

    for i in range(batch):
        feature = {}
        feature1 = np.load(list1[i])
        feature2 = np.load(list2[i])
        print('feature1 shape {} feature2 shape {}'.format(feature1.shape, feature2.shape)) 
        feature['input'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature1.flatten()))
        feature['target'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature2.flatten()))

        features = tf.train.Features(feature=feature)
        example = tf.train.Example(features=features)
        serialized = example.SerializeToString()
        writer.write(serialized)
        print("{}th input {} target {} finished".format(i, list1[i], list2[i]))



list_inp = sorted(glob('./input/2d_magnitude/*'))
list_tar = sorted(glob('./target/2d_magnitude/*'))


print(len(list_inp))
serialize_example(len(list_inp), list_inp, list_tar)

Мой ввод и цель фигуры 2d массив. Таким образом, файл Tfrecords равен [number_of_dataset, x, y]. Примерно 100 000 набор данных был успешно сохранен как файл Tfrecords .

И у меня проблема, когда я читаю файл Tfrecords , чтобы получить пакет. Это мой код read_tfrecords.py:

import tensorflow as tf
import os
import numpy as np

shuffle_buffer_size = 50000
batch_size = 10
record_file = '/data2/dataset/tfrecords/train_set.tfrecords'

raw_dataset = tf.data.TFRecordDataset(record_file)
print('raw_dataset', raw_dataset) # ==> raw_dataset <TFRecordDatasetV2 shapes: (), types: tf.string>

raw_dataset = raw_dataset.repeat()
print('repeat', raw_dataset) # ==> repeat <RepeatDataset shapes: (), types: tf.string>

raw_dataset = raw_dataset.shuffle(shuffle_buffer_size)
print('shuffle', raw_dataset) # ==> shuffle <ShuffleDataset shapes: (), types: tf.string>

raw_dataset = raw_dataset.batch(batch_size, drop_remainder=True)
print('batch', raw_dataset) # ==> batch <BatchDataset shapes: (10,), types: tf.string>

raw_example = next(iter(raw_dataset)) 

parsed = tf.train.Example.FromString(raw_example.numpy()) # ==> read_tfrecords.py:25: RuntimeWarning: Unexpected end-group tag: Not all data was converted

print('parsed', parsed) # ==> ''

input = parsed.features.feature['input'].float_list.value
print('input', input) # ==> []
target = parsed.features.feature['target'].float_list.value
print('target', target) # ==> []

Вот результаты из кода:

raw_dataset <TFRecordDatasetV2 shapes: (), types: tf.string>
repeat <RepeatDataset shapes: (), types: tf.string>
shuffle <ShuffleDataset shapes: (), types: tf.string>
batch <BatchDataset shapes: (10,), types: tf.string>
read_tfrecords.py:25: RuntimeWarning: Unexpected end-group tag: Not all data was converted
  parsed = tf.train.Example.FromString(raw_example.numpy())
parsed
input []
target []

В результате мне интересно, как я получаю партию из Tfrecords файл для обучения. Не могли бы вы дать совет?

...