Регулярные нейронные сети с отслеживанием состояния с помощью fit_generator () - PullRequest
0 голосов
/ 31 мая 2019

Контекст

Я читал некоторые блоги о внедрении рекуррентных нейронных сетей с сохранением состояния в Керасе (например, здесь и здесь ).
Существует также несколько вопросов, касающихся RNN с отслеживанием состояния в стеке, в результате чего этот вопрос приближается к моему.

В связанных руководствах используется fit() -методвместо fit_generator() и передачи состояний путем ручной итерации по эпохам с epochs=1 в fit(), как в этом примере взято из здесь :

# fit an LSTM network to training data
def fit_lstm(train, batch_size, nb_epoch, neurons):
    X, y = train[:, 0:-1], train[:, -1]
    X = X.reshape(X.shape[0], 1, X.shape[1])
    model = Sequential()
    model.add(LSTM(neurons, batch_input_shape=(batch_size, X.shape[1], X.shape[2]), stateful=True))
    model.add(Dense(1))
    model.compile(loss='mean_squared_error', optimizer='adam')
    for i in range(nb_epoch):
        model.fit(X, y, epochs=1, batch_size=batch_size, verbose=0, shuffle=False)
        model.reset_states()
    return model


Мой вопрос

Я бы хотел использовать fit_generator() вместо fit(), но также использовать LSTM / GRU-слои без сохранения состояния.Что мне не хватало в других вопросах stackoverflow, таких как приведенный выше:

  1. Могу ли я действовать так же, как с fit(), что означает установку epochs=1, и повторять его по x раз, устанавливая model.reset_states() в каждой итерации, как в примере?
  2. Или fit_generator() уже сбрасывает состояния только после завершения batch_size, когдаstateful=True используется (что было бы здорово)?
  3. Или fit_generator() сбрасывает состояния после каждой отдельной партии (что может быть проблемой)?

Последний вопрос касается, в частности, этой формы заявления здесь :

Без состояния : В конфигурации LSTM без состояния внутреннее состояние сбрасываетсяпосле каждой обучающей партии или каждой партии при прогнозировании.
С состоянием : В конфигурации с LSTM с состоянием внутреннее состояние сбрасывается только при вызове функции reset_state ().

...