Я последовал примеру поточно-безопасного генератора для Keras fit_generator
, приведенному здесь: https://keras.io/utils/#sequence
Похоже, что индекс пакета (idx
) заблокирован для каждого потока. В моем случае я хочу заблокировать поток для примера индекса. Вот моя реализация:
class CustomGenerator():
def __init__(self):
self.input_ = np.arange(0, 1000)
self.labels = np.arange(0, 1000) * 0.1
self.batch_sz = 5
self.example_index = 0
def __len__(self):
return np.ceil(len(self.input_) / float(self.batch_sz))
def __getitem__(self, batch_idx):
batch_x = np.zeros(self.batch_sz)
batch_y = np.zeros(self.batch_sz)
row = 0
while row < self.batch_sz:
if self.example_index % 2 == 0:
batch_x[row] = self.input_[self.example_index]
batch_y[row] = self.labels[self.example_index]
row += 1
self.example_index += 1
return batch_x, batch_y
cg = CustomGenerator()
batch_idx = 0
while True:
print(cg.__getitem__(batch_idx))
batch_idx += 1
Выводит правильный вывод:
(array([0., 2., 4., 6., 8.]), array([0. , 0.2, 0.4, 0.6, 0.8]))
(array([10., 12., 14., 16., 18.]), array([1. , 1.2, 1.4, 1.6, 1.8]))
(array([20., 22., 24., 26., 28.]), array([2. , 2.2, 2.4, 2.6, 2.8]))
(array([30., 32., 34., 36., 38.]), array([3. , 3.2, 3.4, 3.6, 3.8]))
Как я могу убедиться, что эта реализация будет работать в поточно-ориентированном режиме, т. Е. Разные работники не будут использовать один и тот же example_index
при создании пакетов.