Потокобезопасный пользовательский генератор для keras `fit_generator ()` - PullRequest
0 голосов
/ 20 ноября 2018

Я последовал примеру поточно-безопасного генератора для Keras fit_generator, приведенному здесь: https://keras.io/utils/#sequence Похоже, что индекс пакета (idx) заблокирован для каждого потока. В моем случае я хочу заблокировать поток для примера индекса. Вот моя реализация:

class CustomGenerator():

    def __init__(self):
        self.input_ = np.arange(0, 1000)
        self.labels = np.arange(0, 1000) * 0.1
        self.batch_sz = 5
        self.example_index = 0

    def __len__(self):
        return np.ceil(len(self.input_) / float(self.batch_sz))

    def __getitem__(self, batch_idx):
        batch_x = np.zeros(self.batch_sz)
        batch_y = np.zeros(self.batch_sz)
        row = 0
        while row < self.batch_sz:
            if self.example_index % 2 == 0:
                batch_x[row] = self.input_[self.example_index]
                batch_y[row] = self.labels[self.example_index]
                row += 1
            self.example_index += 1

        return batch_x, batch_y

cg = CustomGenerator()
batch_idx = 0

while True:
    print(cg.__getitem__(batch_idx))
    batch_idx += 1

Выводит правильный вывод:

(array([0., 2., 4., 6., 8.]), array([0. , 0.2, 0.4, 0.6, 0.8]))
(array([10., 12., 14., 16., 18.]), array([1. , 1.2, 1.4, 1.6, 1.8]))
(array([20., 22., 24., 26., 28.]), array([2. , 2.2, 2.4, 2.6, 2.8]))
(array([30., 32., 34., 36., 38.]), array([3. , 3.2, 3.4, 3.6, 3.8]))

Как я могу убедиться, что эта реализация будет работать в поточно-ориентированном режиме, т. Е. Разные работники не будут использовать один и тот же example_index при создании пакетов.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...