Keras fit_generator: случайное увеличение внутри генератора + перетасовка - PullRequest
0 голосов
/ 11 мая 2018

Я создал генератор для ввода его в функцию fit_generator кератов.Генератор создает несколько случайных значений.Вот как я это сделал:

class DataGenerator(object):
    def __init__(self, X_Y_file_path, batch_size, N):
        self.X_Y_file_path = X_Y_file_path
        self.batch_size = size
        self.N = N

    def initialize_zeros(self):
        X = np.zeros((self.batch_size, 1), dtype='int32')
        Y = np.zeros((self.batch_size, 1), dtype='int32')
        Y_neg = np.zeros((self.batch_size, self.N))
        return X, Y, Y_neg

     def generate(self):
        while True:
            i = 0 
            X, Y, Y_neg = initialize_zeros()
            for row in load_data_per_line(self.X_Y_file_path): # load_data_per_line is generator function which goes each line at a time from one file.
                x, y = row
                y_neg = random.sample(id_list, self.N) # a list of id to pick randomly
                X[i] = x
                Y[i] = y
                Y_neg[i] = y_neg
                if i == self.batch_size:
                    yield ([X, Y_neg], Y) # Y_neg goes as input in the model.(not important here. just mentioning)
                    X, Y, Y_neg = initialize_zeros()
                    i = 0

Так что это мой генератор.Тем не менее, с тем же примером данных, кажется, работает правильно.

Мне было интересно, как я могу реализовать в этом генераторе функцию тасования для перемешивания после каждой эпохи?

В поиске битов я узнал о последовательности , которую вы можете переопределить on_epoch_end метод, но не ясно, как я могу реализовать вышеупомянутый генератор с наследованием Sequence.Любая помощь по этому поводу?(кстати, вышеуказанная функция «безопасна» для использования use_multiprocessing в fit_generator?)

Редактировать

X_Y_file_path - это один файл (с известной длиной).load_data_per_line - это функция генератора, которая выдает по одной на строку.

1 Ответ

0 голосов
/ 11 мая 2018

Вы находитесь на правильном пути, чтобы использовать последовательность. Это гарантирует, что каждая точка данных будет видна один раз при использовании с мультиобработкой. Простой способ построения последовательности - это предварительная загрузка необработанных данных, выполняйте обработку на лету каждый раз, когда запрашивается партия.

class MySeq(Sequence):
    def __init__(self, X_Y_file_path, batch_size, N):
        self.X_Y_file_path = X_Y_file_path
        self.batch_size = size
        self.N = N
        self.data = load_data_per_line(self.X_Y_file_path)

     def __len__(self):
        return int(np.ceil(len(self.data) / self.batch_size))

     def __getitem__(self, idx):
        # Just slice the data based on batch index (idx)
        batch_data = self.data[idx*self.batch_size:(idx+1)*self.batch_size]
        X = np.zeros((len(batch_data), 1), dtype='int32')
        Y = np.zeros((len(batch_data), 1), dtype='int32')
        Y_neg = np.zeros((len(batch_data), self.N))
        for row, i in enumerate(data):
            x, y = row
            y_neg = random.sample(id_list, self.N) # a list of id to pick randomly
            X[i] = x
            Y[i] = y
            Y_neg[i] = y_neg
        return [X, Y_neg], Y # This is a single batch

Теперь вы можете выполнить любую обработку на вашем self.data, используя on_epoch_end()

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...