Вы можете использовать функцию, подобную этой:
import tensorflow as tf
def split_tfrecord(tfrecord_path, split_size):
with tf.Graph().as_default(), tf.Session() as sess:
ds = tf.data.TFRecordDataset(tfrecord_path).batch(split_size)
batch = ds.make_one_shot_iterator().get_next()
part_num = 0
while True:
try:
records = sess.run(batch)
part_path = tfrecord_path + '.{:03d}'.format(part_num)
with tf.python_io.TFRecordWriter(part_path) as writer:
for record in records:
writer.write(record)
part_num += 1
except tf.errors.OutOfRangeError: break
Например, чтобы разделить файл my_records.tfrecord
на части по 100 записей в каждой, вы должны сделать:
split_tfrecord(my_records.tfrecord, 100)
Это создаст несколько небольших файлов записей my_records.tfrecord.000
, my_records.tfrecord.001
и т. Д.