Пользовательский генератор данных для Keras LSTM с TimeSeriesGenerator - PullRequest
0 голосов
/ 14 мая 2018

Поэтому я пытаюсь использовать Keras ' fit_generator с настраиваемым генератором данных для подачи в сеть LSTM.

Что работает

Чтобы проиллюстрировать проблему,Я создал игрушечный пример, пытаясь предсказать следующее число в простой возрастающей последовательности, и я использую Keras TimeseriesGenerator для создания экземпляра Sequence:

WINDOW_LENGTH = 4
data = np.arange(0,100).reshape(-1,1)
data_gen = TimeseriesGenerator(data, data, length=WINDOW_LENGTH,
                               sampling_rate=1, batch_size=1)

Я использую простой LSTMсеть:

data_dim = 1
input1 = Input(shape=(WINDOW_LENGTH, data_dim))
lstm1 = LSTM(100)(input1)
hidden = Dense(20, activation='relu')(lstm1)
output = Dense(data_dim, activation='linear')(hidden)

model = Model(inputs=input1, outputs=output)
model.compile(loss='mse', optimizer='rmsprop', metrics=['accuracy'])

и обучаем ее с помощью функции fit_generator:

model.fit_generator(generator=data_gen,
                    steps_per_epoch=32,
                    epochs=10)

И это отлично тренируется, и модель делает прогнозы, как и ожидалось.

проблема

Теперь проблема в том, что в моей неигровой ситуации я хочу обработать данные, поступающие из TimeseriesGenerator, перед тем, как передать данные в fit_generator.В качестве шага к этому я создаю функцию генератора, которая просто оборачивает используемый ранее TimeseriesGenerator.

def get_generator(data, targets, window_length = 5, batch_size = 32):
    while True:
        data_gen = TimeseriesGenerator(data, targets, length=window_length, 
                                       sampling_rate=1, batch_size=batch_size)
        for i in range(len(data_gen)):
            x, y = data_gen[i]
            yield x, y

data_gen_custom = get_generator(data, data,
                                window_length=WINDOW_LENGTH, batch_size=1)

Но теперь странно то, что когда я тренирую модель, как раньше, но использую этот генератор в качестве входных данных,

model.fit_generator(generator=data_gen_custom,
                    steps_per_epoch=32,
                    epochs=10)

Нет ошибки, но ошибка обучения повсюду (прыжки вверх и вниз вместо последовательного падения, как это было с другим подходом), и модель не учитсяделать хорошие прогнозы.

Есть идеи, что я делаю не так с моим подходом генератора?

Ответы [ 2 ]

0 голосов
/ 04 июня 2019

У меня лично были проблемы с кодом nuric.По какой-то причине у меня была ошибка, из-за которой я не смог написать сценарий.Вот мое возможное исправление.Дайте мне знать, может ли это сработать?

class CustomGen(TimeseriesGenerator):
    def __getitem__(self, idx):
        x,y = super().__getitem__(idx)
        return x, y
0 голосов
/ 14 мая 2018

Это может быть связано с тем, что тип объекта изменяется с Sequence, то есть TimeseriesGenerator на универсальный генератор.Функция fit_generator обрабатывает их по-разному.Более чистым решением было бы унаследовать класс и переопределить бит обработки:

class CustomGen(TimeseriesGenerator):
  def __getitem__(self, idx):
    x, y = super()[idx]
    # do processing here
    return x, y

И использовать этот класс, как раньше, так как остальная внутренняя логика останется прежней.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...