tenorflow - tf.data.Dataset случайным образом пропускает образцы перед пакетированием, чтобы получить разные партии - PullRequest
0 голосов
/ 23 ноября 2018

Моя модель использует хронологически упорядоченные последовательности в каждой входной партии.Поэтому я создаю пакеты, прежде чем перетасовать свои входные данные.В связи с этим возникает проблема, заключающаяся в том, что пакеты всегда включают одни и те же выборки данных по всему набору данных (начиная с одних и тех же индексов, смещенных на batch_size). Я решил эту проблему путем кэширования исходного набора данных и выборки из пропущенных наборов данных, однако это израсходовало память.довольно быстро (хотя мой набор данных имеет только 150 МБ):

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.window(size=window_size, shift=window_shift, stride=window_stride, drop_remainder=True).flat_map(lambda x: x.batch(window_size))
dataset = dataset.map(process_fn, num_parallel_calls=8)
dataset = dataset.cache()
datasets = []
for i in range(0, batch_size):
    d = dataset.skip(i)
    d = d.batch(batch_size, drop_remainder=True)
    datasets.append(d)
dataset = tf.data.experimental.sample_from_datasets(datasets)
dataset = dataset.shuffle(buffer_size=30000, reshuffle_each_iteration=False)
dataset = dataset.repeat()

Есть ли другой способ добиться такого поведения?Я хочу охватить все возможные индексы для начала первой последовательности внутри пакета.

1 Ответ

0 голосов
/ 04 апреля 2019

Вы поглощаете память, потому что вы перетасовываете целые партии - также пропуск может быть не очень эффективным.Поскольку ваши данные кажутся целыми в памяти, вы, возможно, могли бы сэмплировать свои данные непосредственно в python, не слишком заботясь о производительности:

def make_batch(start_idx):
  batch = np.empty((batch_size, window_size), dtype=data.dtype)
  for batch_idx, data_idx in enumerate(
      range(start_idx, start_idx + window_shift * batch_size, window_shift)):
    batch[batch_idx] = data[data_idx:data_idx + window_size * window_stride:window_stride]
  return batch

dataset = (tf.data.Dataset
  .range(len(data) - window_stride * (window_size - 1) - window_shift * (batch_size- 1))
  .shuffle(buffer_size=30000, reshuffle_each_iteration=False)
  .map(lambda x: tf.py_func(make_batch, [x], tf.float32)) # assuming your data is float32
  .repeat()
  .prefetch(1)) # you might want to consider prefetching for performance

Перестановка теперь происходит по индексам, а не по целым пакетам, поэтомунамного меньший объем памяти.

...