keras fit_generator читает фрагменты из hdfstore - PullRequest
0 голосов
/ 27 апреля 2018

Я пытаюсь построить генератор для модели Keras, который будет обучаться в большом магазине hdf. Чтобы ускорить обучение, я предварительно рассчитал все функции, в т.ч. горячая кодировка уже в hdfstore. Таким образом, призыв от этого должен быть прямым.

Чтобы передать куски моих данных в сеть, я пытаюсь использовать fit_generator, но изо всех сил пытаюсь его запустить и запустить.

Генератор:

def myGenerator(myStore, generateFrom,generateTo):
 # Create empty arrays to contain batch of features and labels#
    while True:
        X = pd.read_hdf(myStore,'X',start=generateFrom,stop=generateTo)
        y = pd.read_hdf(myStore,'y',start=generateFrom,stop=generateTo)
        yield X,y

Сеть и установка:

def get_model(shape):
    '''Create a keras model.'''
    inputlayer = Input(shape=shape)

    model = BatchNormalization()(inputlayer)
    model = Dense(1024, activation='relu')(model)
    model = Dropout(0.25)(model)
    model = BatchNormalization()(inputlayer)
    model = Dense(512, activation='relu')(model)
    model = Dropout(0.25)(model)
    model = BatchNormalization()(inputlayer)
    model = Dense(256, activation='relu')(model)
    model = Dropout(0.25)(model)
    model = BatchNormalization()(inputlayer)
    model = Dense(128, activation='relu')(model)
    model = Dropout(0.25)(model)

    # 11 because background noise has been taken out
    model = Dense(2, activation='tanh')(model)

    model = Model(inputs=inputlayer, outputs=model)

    return model
shape = (6603,10000)
model = get_model(shape)
model.compile(loss='mean_squared_error', optimizer=Adam(), metrics=['accuracy'])
#X = generator(myStore)
#Xt = generator(myStore)
labelbinarizer = LabelBinarizer()
y = labelbinarizer.fit_transform(y)
#yt = labelbinarizer.fit_transform(yt)

generateFrom = 0
for i in range(10):
    generateTo=generateFrom+10000
    model.fit_generator(
        generator=myGenerator(myStore,generateFrom,generateTo),
        epochs=1,
        steps_per_epoch=X[0].shape[0] // 1000)
    generateFrom=generateTo

Я пробовал и то и другое: поместить fit_generator в петлю и подключить диапазон (как показано выше), а также обработать диапазон внутри генератора. Оба не работают. В настоящее время работает в

TypeError: 'generator' object is not subscriptable

Скорее всего, у меня возникло недопонимание того, как fit_generator () должен использоваться в этом контексте. Большинство примеров вокруг генерируют тензоры из изображений.

Любая подсказка приветствуется. Спасибо

1 Ответ

0 голосов
/ 27 апреля 2018

Функция read_hdf возвращает объект panda, вам нужно преобразовать его в массив numpy.

...