Генератор Кераса продолжает перемешивать, хотя его просят не - PullRequest
0 голосов
/ 17 марта 2020

Я использую генератор данных Keras, по умолчанию инициализирующий shuffle со значением false:

class data_generator(keras.utils.Sequence):
    def __init__(self, frames, labels, batch_size, data_dir, shuffle=False):
        'Initialization'
        self.batch_size = batch_size
        self.labels = labels
        self.frames = frames
        self.data_dir = data_dir
        self.shuffle = shuffle
        self.size = len(self.frames)
        self.on_epoch_end()

  ...

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.frames))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

   ...

И вот как я создаю экземпляры для обучения и проверки:

train_generator = data_generator(x_train[:num_train_examples], y_train[:num_train_examples], batch_size, data_dir)
val_generator = data_generator(x_train[num_train_examples:], y_train[num_train_examples:], batch_size, data_dir)

И затем обучаю модель:

model.fit_generator(train_generator,
                        validation_data=val_generator,
                        callbacks=[history],
                        epochs=num_epochs)

Но генератор продолжает выдавать случайные индексы:

starting training
Epoch 1/1

batch start: 0, batch end: 2

batch start: 24, batch end: 26

batch start: 2, batch end: 4

batch start: 114, batch end: 116

batch start: 4, batch end: 6

batch start: 60, batch end: 62

batch start: 6, batch end: 8

batch start: 68, batch end: 70

batch start: 8, batch end: 10

batch start: 94, batch end: 96

Что я могу сделать, чтобы не перемешать?

A getitem функция из класса генератора:

    def __getitem__(self, index):
        'Generate one batch of data'
        x_batch, y_batch = self.__data_generation(index)

        return x_batch, y_batch

    def __data_generation(self, index):
        'Generates data containing batch_size samples'
        limit = min(self.size, (index + 1)*self.batch_size)
        x_batch = []
        print('\nbatch start: ' + str(index*self.batch_size) + ', batch end: ' + str(limit))
        for frame in self.frames[index*self.batch_size:limit]:
            video_array = np.load(self.data_dir + '/' + frame + '.npy')
            x_batch.append(np.array(video_array))

        return np.array(x_batch), self.labels[index*self.batch_size:limit]

РЕДАКТИРОВАТЬ: теперь я могу видеть шаблон, выглядит как случайные партии чередуются со случайными

1 Ответ

0 голосов
/ 18 марта 2020

Я предполагаю, что проблема может быть в вашей функции __len__(self) (если вы ее определили). Я добавил функцию __len__(self) в ваш код и попробовал, теперь она не тасуется. Код здесь:

class data_generator(keras.utils.Sequence):
    def __init__(self, frames, labels, batch_size, data_dir, shuffle=False):
        'Initialization'
        self.batch_size = batch_size
        self.labels = labels
        self.frames = frames
        self.data_dir = data_dir
        self.shuffle = shuffle
        self.size = len(self.frames)
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(self.size/self.batch_size))

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.frames))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        'Generate one batch of data'
        x_batch, y_batch = self.__data_generation(index)
        return x_batch, y_batch

    # def __data_generation(self, index):
    #     'Generates data containing batch_size samples'
    #     current_indices = self.indexes[index*self.batch_size:(index + 1)*self.batch_size]
    #     x_batch = []
    #     y_batch = []
    #     for idx in current_indices:
    #         # video_array = np.load(self.data_dir + '/' + self.frames[idx] + '.npy')
    #         # x_batch.append(np.array(video_array))
    #         y_batch.append(self.labels[idx])

    #     return np.array(x_batch), y_batch

    def __data_generation(self, index):
        'Generates data containing batch_size samples'
        limit = min(self.size, (index + 1)*self.batch_size)
        x_batch = []
        print('\nbatch start: ' + str(index*self.batch_size) + ', batch end: ' + str(limit))
        for frame in self.frames[index*self.batch_size:limit]:
            video_array = np.load(self.data_dir + '/' + frame + '.npy')
            x_batch.append(np.array(video_array))
        return np.array(x_batch), self.labels[index*self.batch_size:limit]

Приведенный выше код работает, как вы ожидали, он не тасует. Однако, как вы определили свою функцию __ data_generation , она не будет работать, если вы хотите, чтобы она перемешалась. Поэтому я написал свою собственную __data_generation функцию, которую вы можете увидеть закомментированной. Если вы используете это, вы можете получить желаемую функциональность. Если shuffle равно True , оно будет перетасовано. Если shuffle равно False , оно не будет перемешиваться. Надеюсь, это поможет.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...