Построение генератора для fit_generator (), который увеличивает размер пакета с увеличением эпохи - PullRequest
0 голосов
/ 24 октября 2018

Что касается следующей статьи: Не снижайте скорость обучения, увеличьте размер партии

TL; DR Как настроитьгенератор, который увеличивает размер пакета с увеличением эпохи?

(читайте дальше, только если вы хотите помочь с редактированием моего кода)


Цель: иметь набор обучающих данных (длярегрессия) (x_train, y_train) реализовать пакетный генератор для ANN, встроенного в Keras.Основная идея кода:

  1. Какое количество эпох?В зависимости от ответа установите размер пакета
  2. До конца набора данных генерируйте партии с заданным размером пакета (за исключением последнего, размер которого может быть меньше)

Я предполагал, что когда model.fit_generator(data_gen(x_train, y_train)) достигнет последней партии, она перейдет в новую эпоху обучения.

Может ли кто-нибудь подсказать, как проверить правильность реализации генератора партий?Вот мой код, в конце он выдает ошибку, и я не совсем понимаю, как заставить его работать правильно.

Код

def data_gen(features, targets):
    global epoch
    batches_produced_in_current_epoch = 1
    epoch += 1

    while True:
        print('==============================================================================')
        total_amount_of_samples = features.shape[0]

        if epoch <= 2:
            batch_size = 2 ** 4
        elif epoch <= 3:
            batch_size = 2 ** 5
        elif epoch <= 2000:    
            batch_size = 2 ** 6 

        how_many_batches_in_this_epoch_total = features.shape[0] / batch_size

        if not how_many_batches_in_this_epoch_total.is_integer():
            how_many_batches_in_this_epoch_total = int(how_many_batches_in_this_epoch_total) + 1
        else:
            how_many_batches_in_this_epoch_total = int(how_many_batches_in_this_epoch_total)

        print('Batch size:\t'+str(batch_size))
        print('Batches produced in the current epoch:\t'+str(batches_produced_in_current_epoch))
        print('How many batches in the epoch:\t{}'.format(how_many_batches_in_this_epoch_total))


        if int(batches_produced_in_current_epoch) == int(how_many_batches_in_this_epoch_total - 1):
            print('hi')
            batch_x = np.zeros((int(features.shape[0] % batch_size), features.shape[1]))
            batch_y = np.zeros((int(targets.shape[0] % batch_size), targets.shape[1]))
        else:
            batch_x = np.zeros((batch_size, features.shape[1]))
            batch_y = np.zeros((batch_size, targets.shape[1]))

        print('batch sizes:\t{}, {}'.format(batch_x.shape, batch_y.shape))

        if int(batches_produced_in_current_epoch) == int(how_many_batches_in_this_epoch_total - 1):
            batch_x[:,:] = x_train[batches_produced_in_current_epoch*batch_size:,:]
            batch_y[:,:] = y_train[batches_produced_in_current_epoch*batch_size:,:]
        else:
            batch_x[:,:] = x_train[batches_produced_in_current_epoch*batch_size:(batches_produced_in_current_epoch + 1)*batch_size,:]
            batch_y[:,:] = y_train[batches_produced_in_current_epoch*batch_size:(batches_produced_in_current_epoch + 1)*batch_size,:]

        print('Shapes:\t{}'.format(batch_x.shape, batch_y.shape))
        print('Batch size:\t'+str(batch_size))
        print('Batches produced in the current epoch:\t'+str(batches_produced_in_current_epoch))
        print('How many batches in the epoch:\t{}'.format(how_many_batches_in_this_epoch_total))

        batches_produced_in_current_epoch += 1

        if batches_produced_in_current_epoch == how_many_batches_in_this_epoch_total:
            epoch += 1    



        yield batch_x, batch_y

Тогда для:

Код

import numpy as np 

x_train = np.random.randn(20,6)
y_train = np.random.randn(20,1)
epoch = 0
for x, y in data_gen(x_train, y_train):
    print(x.shape, y.shape)

Я получаювывод:

Вывод и сообщение об ошибке

==============================================================================
Batch size: 16
Batches produced in the current epoch:  1
How many batches in the epoch:  2
hi
batch sizes:    (4, 6), (4, 1)
Shapes: (4, 6)
Batch size: 16
Batches produced in the current epoch:  1
How many batches in the epoch:  2
(4, 6) (4, 1)
==============================================================================
Batch size: 16
Batches produced in the current epoch:  2
How many batches in the epoch:  2
batch sizes:    (16, 6), (16, 1)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-145-2883bc9d95ea> in <module>()
----> 1 for x, y in data_gen(x_train, y_train):
      2     print(x.shape, y.shape)

<ipython-input-144-191d762315fa> in data_gen(features, targets)
     57             batch_y[:,:] = y_train[batches_produced_in_current_epoch*batch_size:,:]
     58         else:
---> 59             batch_x[:,:] = x_train[batches_produced_in_current_epoch*batch_size:(batches_produced_in_current_epoch + 1)*batch_size,:]
     60             batch_y[:,:] = y_train[batches_produced_in_current_epoch*batch_size:(batches_produced_in_current_epoch + 1)*batch_size,:]
     61 

ValueError: could not broadcast input array from shape (0,6) into shape (16,6)

Спасибо.

...