Как объединить несколько наборов данных tfrecord в один набор данных? - PullRequest
1 голос
/ 18 марта 2019

Предположим, у меня есть 3 файла tfrecord, а именно neg.tfrecord, pos1.tfrecord, pos2.tfrecord.

Размер моей партии равен 500, включая 300 отрицательных данных, 100 данных pos1 и 100 данных pos2. Как я могу получить желаемый TFRecordDataset?

Я буду использовать этот объект TFRecordDataset в keras.fit () (Eager Execution).

Моя версия тензорного потока - 1.13.1. Я нахожу API в tf.data.Dataset , например interleave, concatenate, zip, но, похоже, я не могу решить свою проблему.

Раньше я пытался получить итератор для каждого набора данных, а затем вручную получать данные после получения данных, но это было неэффективно и использование графического процессора было невысоким.

А в этом вопросе я использую interleave ниже:

tfrecord_files = ['neg.tfrecord', 'pos1.tfrecord', 'pos2.tfrecord']
dataset = tf.data.Dataset.from_tensor_slices(tfrecord_files)
def _parse(x):
    x = tf.data.TFRecordDataset(x)
    return x
dataset = dataset.interleave(_parse, cycle_length=4, block_length=1)
dataset = dataset.apply(tf.data.experimental.map_and_batch(_parse_image_function, 500))

и я получил эту партию:

neg pos1 pos2 neg pos1 pos2 ...............

Но я хочу вот что:

neg neg neg pos1 pos2 neg neg neg pos1 pos2 .................

Что мне делать?

С нетерпением жду ответа.

1 Ответ

1 голос
/ 18 марта 2019

Я воспроизвел что-то вроде того, что вы сказали, используя строковые данные:

import tensorflow as tf

def string_data(s):
    return tf.sparse.to_dense(tf.strings.split([s]), default_value='')[0]

data = [' '.join(['neg'] * 30), ' '.join(['pos1'] * 10), ' '.join(['pos2'] * 10)]
step_sizes = tf.constant([3, 1, 1], dtype=tf.int64)
ds = (tf.data.Dataset.from_tensor_slices((data, step_sizes))
      .interleave(lambda d, s: (tf.data.Dataset.from_tensor_slices(string_data(d))
                                .batch(s)),
                  cycle_length=len(data))
      .flat_map(tf.data.Dataset.from_tensor_slices))
iter = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(iter).decode(), end=', ')
        except tf.errors.OutOfRangeError: break
    print()

Вывод:

neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, 

В реальном случае вы бы заменили data списком файловимена и tf.data.Dataset.from_tensor_slices(string_data(d)) с tf.data.TFRecordDataset(d), но в остальном все должно работать аналогично.

РЕДАКТИРОВАТЬ: Я только что понял, что вы на самом деле хотели, чтобы пакет всех элементов был упорядочен таким образом, а не только один элемент за раз,так что я полагаю, вам нужно добавить еще один batch вызов в конце.

...