РЕДАКТИРОВАТЬ 2:
Не уверен, что это будет быстрее, но вы также можете просто сделать что-то вроде этого. Он все еще опирается на расширенную индексацию, хотя и на смежные данные, так что, может быть, это немного лучше?:
import numpy as np
def get_time_series(data, indices, look_back):
# Make sure indices are big enough
indices = indices[indices >= look_back]
# Make indexing matrix
idx = indices[:, np.newaxis] + np.arange(-look_back, 0)
# Make batch
return data[idx]
Вы можете использовать его, например, так:
import numpy as np
def get_time_series(data, indices, look_back):
indices = indices[indices >= look_back]
idx = indices[:, np.newaxis] + np.arange(-look_back, 0)
return data[idx]
def make_batches(data, look_back, batch_size):
indices = np.random.permutation(np.arange(look_back, len(data) + 1))
for i in range(0, len(indices), batch_size):
yield get_time_series(data, indices[i:i + batch_size], look_back)
data = ...
look_back = ...
batch_size = ...
for batch in make_batches(data, look_back, batch_size):
# Use batch
EDIT:
Если вы хотите перетасовать примеры, вы можете сначала создать скользящее окно для всего набора данных (которое не должно занимать память или время), а затем взять пакеты из перетасованного индекса:
# Make sliding window with the previous function
data_sw = get_time_series(data, 0, look_back, len(data))
# Random index
batch_idx = np.random.permutation(len(data_sw))
# To get the first batch
batch = data_sw[batch_idx[:batch_size]]
Я думаю, что это делает то, что вы хотите, и должно быть намного быстрее, чем использование циклов:
import numpy as np
def get_time_series(data, index, look_back, batch_size):
from numpy.lib.stride_tricks import as_strided
# Index should be at least as big as look_back to have enough elements before it
index = max(index, look_back)
# Batch size should not go beyond the array
batch_size = min(batch_size, len(data) - index + 1)
# Relevant slice for the batch
data_slice = data[index - look_back:index + batch_size]
# Reshape with stride tricks as a "sliding window"
data_strides = data_slice.strides
batch_shape = (batch_size, look_back, data_slice.shape[-1])
batch_strides = (data_strides[0], data_strides[0], data_strides[1])
return as_strided(data_slice, batch_shape, batch_strides, writeable=False)
# Test
data = np.arange(300).reshape((100, 3))
batch = get_time_series(data, 20, 5, 4)
print(batch)
Выход:
[[[45 46 47]
[48 49 50]
[51 52 53]
[54 55 56]
[57 58 59]]
[[48 49 50]
[51 52 53]
[54 55 56]
[57 58 59]
[60 61 62]]
[[51 52 53]
[54 55 56]
[57 58 59]
[60 61 62]
[63 64 65]]
[[54 55 56]
[57 58 59]
[60 61 62]
[63 64 65]
[66 67 68]]]