У меня проблемы с чтением Tfrecords файла при получении пакета.
Во-первых, это мой read_tfrecords.py
файл.
import tensorflow as tf
import os
from glob import glob
import numpy as np
def serialize_example(batch, list1, list2):
filename = "./train_set.tfrecords"
writer = tf.io.TFRecordWriter(filename)
for i in range(batch):
feature = {}
feature1 = np.load(list1[i])
feature2 = np.load(list2[i])
print('feature1 shape {} feature2 shape {}'.format(feature1.shape, feature2.shape))
feature['input'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature1.flatten()))
feature['target'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature2.flatten()))
features = tf.train.Features(feature=feature)
example = tf.train.Example(features=features)
serialized = example.SerializeToString()
writer.write(serialized)
print("{}th input {} target {} finished".format(i, list1[i], list2[i]))
list_inp = sorted(glob('./input/2d_magnitude/*'))
list_tar = sorted(glob('./target/2d_magnitude/*'))
print(len(list_inp))
serialize_example(len(list_inp), list_inp, list_tar)
Мой ввод и цель фигуры 2d массив. Таким образом, файл Tfrecords равен [number_of_dataset, x, y]
. Примерно 100 000 набор данных был успешно сохранен как файл Tfrecords .
И у меня проблема, когда я читаю файл Tfrecords , чтобы получить пакет. Это мой код read_tfrecords.py
:
import tensorflow as tf
import os
import numpy as np
shuffle_buffer_size = 50000
batch_size = 10
record_file = '/data2/dataset/tfrecords/train_set.tfrecords'
raw_dataset = tf.data.TFRecordDataset(record_file)
print('raw_dataset', raw_dataset) # ==> raw_dataset <TFRecordDatasetV2 shapes: (), types: tf.string>
raw_dataset = raw_dataset.repeat()
print('repeat', raw_dataset) # ==> repeat <RepeatDataset shapes: (), types: tf.string>
raw_dataset = raw_dataset.shuffle(shuffle_buffer_size)
print('shuffle', raw_dataset) # ==> shuffle <ShuffleDataset shapes: (), types: tf.string>
raw_dataset = raw_dataset.batch(batch_size, drop_remainder=True)
print('batch', raw_dataset) # ==> batch <BatchDataset shapes: (10,), types: tf.string>
raw_example = next(iter(raw_dataset))
parsed = tf.train.Example.FromString(raw_example.numpy()) # ==> read_tfrecords.py:25: RuntimeWarning: Unexpected end-group tag: Not all data was converted
print('parsed', parsed) # ==> ''
input = parsed.features.feature['input'].float_list.value
print('input', input) # ==> []
target = parsed.features.feature['target'].float_list.value
print('target', target) # ==> []
Вот результаты из кода:
raw_dataset <TFRecordDatasetV2 shapes: (), types: tf.string>
repeat <RepeatDataset shapes: (), types: tf.string>
shuffle <ShuffleDataset shapes: (), types: tf.string>
batch <BatchDataset shapes: (10,), types: tf.string>
read_tfrecords.py:25: RuntimeWarning: Unexpected end-group tag: Not all data was converted
parsed = tf.train.Example.FromString(raw_example.numpy())
parsed
input []
target []
В результате мне интересно, как я получаю партию из Tfrecords файл для обучения. Не могли бы вы дать совет?