Вы поглощаете память, потому что вы перетасовываете целые партии - также пропуск может быть не очень эффективным.Поскольку ваши данные кажутся целыми в памяти, вы, возможно, могли бы сэмплировать свои данные непосредственно в python, не слишком заботясь о производительности:
def make_batch(start_idx):
batch = np.empty((batch_size, window_size), dtype=data.dtype)
for batch_idx, data_idx in enumerate(
range(start_idx, start_idx + window_shift * batch_size, window_shift)):
batch[batch_idx] = data[data_idx:data_idx + window_size * window_stride:window_stride]
return batch
dataset = (tf.data.Dataset
.range(len(data) - window_stride * (window_size - 1) - window_shift * (batch_size- 1))
.shuffle(buffer_size=30000, reshuffle_each_iteration=False)
.map(lambda x: tf.py_func(make_batch, [x], tf.float32)) # assuming your data is float32
.repeat()
.prefetch(1)) # you might want to consider prefetching for performance
Перестановка теперь происходит по индексам, а не по целым пакетам, поэтомунамного меньший объем памяти.