Поэтому я пытаюсь использовать Keras ' fit_generator с настраиваемым генератором данных для подачи в сеть LSTM.
Что работает
Чтобы проиллюстрировать проблему,Я создал игрушечный пример, пытаясь предсказать следующее число в простой возрастающей последовательности, и я использую Keras TimeseriesGenerator для создания экземпляра Sequence:
WINDOW_LENGTH = 4
data = np.arange(0,100).reshape(-1,1)
data_gen = TimeseriesGenerator(data, data, length=WINDOW_LENGTH,
sampling_rate=1, batch_size=1)
Я использую простой LSTMсеть:
data_dim = 1
input1 = Input(shape=(WINDOW_LENGTH, data_dim))
lstm1 = LSTM(100)(input1)
hidden = Dense(20, activation='relu')(lstm1)
output = Dense(data_dim, activation='linear')(hidden)
model = Model(inputs=input1, outputs=output)
model.compile(loss='mse', optimizer='rmsprop', metrics=['accuracy'])
и обучаем ее с помощью функции fit_generator
:
model.fit_generator(generator=data_gen,
steps_per_epoch=32,
epochs=10)
И это отлично тренируется, и модель делает прогнозы, как и ожидалось.
проблема
Теперь проблема в том, что в моей неигровой ситуации я хочу обработать данные, поступающие из TimeseriesGenerator, перед тем, как передать данные в fit_generator
.В качестве шага к этому я создаю функцию генератора, которая просто оборачивает используемый ранее TimeseriesGenerator.
def get_generator(data, targets, window_length = 5, batch_size = 32):
while True:
data_gen = TimeseriesGenerator(data, targets, length=window_length,
sampling_rate=1, batch_size=batch_size)
for i in range(len(data_gen)):
x, y = data_gen[i]
yield x, y
data_gen_custom = get_generator(data, data,
window_length=WINDOW_LENGTH, batch_size=1)
Но теперь странно то, что когда я тренирую модель, как раньше, но использую этот генератор в качестве входных данных,
model.fit_generator(generator=data_gen_custom,
steps_per_epoch=32,
epochs=10)
Нет ошибки, но ошибка обучения повсюду (прыжки вверх и вниз вместо последовательного падения, как это было с другим подходом), и модель не учитсяделать хорошие прогнозы.
Есть идеи, что я делаю не так с моим подходом генератора?