Правильное использование fit_generator в керасе - PullRequest
1 голос
/ 20 мая 2019

Всех приветствую.Пытаясь понять, как fit_generator работает в keras.

У меня есть набор данных, в каждом файле - 100 изображений и 100 меток.

Я написал этот генератор:

def GenerateData(self):

    while True:

        complete_x1 = np.zeros((500, 50, 50, 3))
        complete_x2 = np.zeros((500, 50, 50, 3))
        complete_y1 = np.zeros((500, 3))
        complete_y2 = np.zeros((500, 2))

        done = 0

        while done < 500:

            data = np.load("{}/data_resized_{}.npy".format(self._patch, self._LastID))

            self.Log('\nLoad ALL data. ID: {} - Done: {}'.format(self._LastID, done))

            for data_x1, data_x2, data_y1, data_y2 in data:

                data_x1 = self.random_transform(data_x1)

                data_x2 = self.random_transform(data_x2)

                data_x1 = self.ImageProcessing(data_x1, 0)

                data_x2 = self.ImageProcessing(data_x2, 1)

                data_x1 = np.array(data_x1).astype('float32')
                data_x1 /= 255

                data_x2 = np.array(data_x2).astype('float32')
                data_x2 /= 255

                complete_x1[done] = data_x1
                complete_x2[done] = data_x2

                complete_y1[done] = data_y1
                complete_y2[done] = data_y2

                done += 1

            self._LastID += 1

            if self._LastID >= 1058:
                self._LastID = 0

        yield [np.array(complete_x1), np.array(complete_x2)], [np.array(complete_y1), np.array(complete_y2)]

IВсего 1058 файлов.Получается 105800 изображений с метками.

Модель обучения:

model.fit_generator(data.GenerateData(), samples_per_epoch=1058/500, nb_epoch=15, verbose=1, workers=1)

Все вроде бы хорошо, но!

В самом начале обучения GenerateData распечатываетследующее:

Загрузить ВСЕ данные.ID: 0 - Выполнено: 0

Загрузить ВСЕ данные.ID: 1 - Готово: 100

Загрузить ВСЕ данные.ID: 2 - Готово: 200

Загрузить ВСЕ данные.ID: 3 - Готово: 300

Загрузить ВСЕ данные.ID: 4 - Готово: 400

Загрузить ВСЕ данные.ID: 5 - Готово: 0

Загрузить ВСЕ данные.ID: 6 - Готово: 100

Загрузить ВСЕ данные.ID: 7 - Готово: 200

Загрузить ВСЕ данные.ID: 8 - Готово: 300

Загрузить ВСЕ данные.ID: 9 - Готово: 400

Загрузить ВСЕ данные.ID: 10 - Готово: 0

И это происходит до того, как файл с идентификатором 59. Получается ... Пропускает ли он все, что идет до 59 файла?5900 изображений?

Он просто загружает 500 изображений, после чего проходит выход и запускается снова, с идентификатором файла, на котором он закончил, но поезд не работает.

Вотчто следует после 59-го файла:

Загрузить ВСЕ данные.ID: 59 - Выполнено: 400 1/2 [=============> ................] - ETA: 4s - потеря: 2.8177- density_18_loss: 2.0145 - dens_21_loss: 0.8032 - dens_18_acc: 0,2140 - dens_21_acc: 0,5780 Загрузить ВСЕ данные.ID: 60 - Готово: 0

Загрузить ВСЕ данные.ID: 61 - Готово: 100

Загрузить ВСЕ данные.ID: 62 - Выполнено: 200

Загрузить ВСЕ данные.ID: 63 - Готово: 300

Загрузить ВСЕ данные.ID: 64 - Выполнено: 400 2/2 [===========================> ..] - ETA: 0s - потеря: 2.7260- density_18_loss: 1.7077 - dens_21_loss: 1.0183 - dens_18_acc: 0.2720 - dens_21_acc: 0.5890 Загрузить ВСЕ данные.ID: 65 - Готово: 0

Загрузить ВСЕ данные.ID: 66 - Готово: 100

Почему это происходит?

1 Ответ

1 голос
/ 20 мая 2019

Вы получаете это поведение, потому что вы установили workers в 1, и задача создания данных и задача обучения выполняются в отдельных потоках. Задача обучения выполняется в главном потоке, а задача создания данных - в отдельном потоке, где число потоков зависит от аргумента workers.

Если бы аргумент workers был равен 0, генератор данных работал бы в главном потоке, и результат был бы тем, что вы ожидаете.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...