tenorflow - входной конвейер с несколькими файлами TFRecord + tf.contrib.data.sliding_window_batch () - PullRequest
0 голосов
/ 13 октября 2018

У меня есть несколько TFRecord файлов, которые содержат определенный период моих данных.Содержащие точки данных являются последовательными внутри каждого файла, но не последовательными между файлами.Как часть моего входного конвейера, я использую tf.contrib.data.sliding_window_batch для обработки окна точек данных следующим образом:

filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
dataset = tf.data.TFRecordDataset(filenames)

dataset = dataset.map(parser_fn, num_parallel_calls=6)
dataset = dataset.map(preprocessing_fn, num_parallel_calls=6)
dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=y + z)) # sliding window
dataset = dataset.map(lambda x: prepare_fn(x, y, z))
dataset = dataset.shuffle(buffer_size=100000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
dataset = dataset.prefetch(2)

Как я могу предотвратить перекрытие окна между точками данных из разных файлов?

Ответы [ 2 ]

0 голосов
/ 15 октября 2018

Альтернативой может быть создание пакетов для каждого файла независимо и чередование результатов:

def interleave_fn(filename):
  dataset = dataset.map(parser_fn, num_parallel_calls=6)
  dataset = dataset.map(preprocessing_fn, num_parallel_calls=6)
  dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=y + z)) # sliding window

filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.interleave(interleave_fn, num_parallel_calls=...)
dataset = dataset.map(lambda x: prepare_fn(x, y, z))
dataset = dataset.shuffle(buffer_size=1000000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
dataset = dataset.prefetch(2)

Это, вероятно, более производительно, так как обходит этап фильтрации.

0 голосов
/ 13 октября 2018

Решено с помощью tf.Dataset.filter(predicate).

filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
dataset = tf.data.TFRecordDataset(filenames)

dataset = dataset.map(parser_fn, num_parallel_calls=6)
dataset = dataset.map(preprocessing_fn, num_parallel_calls=6)
dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=y + z)) # sliding window
dataset = dataset.filter(lambda x: tf.equal(x['timeframe'][0], x['timeframe'][-1]))
dataset = dataset.map(lambda x: prepare_fn(x, y, z))
dataset = dataset.shuffle(buffer_size=100000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
dataset = dataset.prefetch(2)
...