Извлечение состояний из каждой партии - train_on_batch в Keras (Tensorflow) - PullRequest
0 голосов
/ 02 ноября 2018

Чего я пытаюсь достичь

Я экспериментирую с настраиваемым генератором пакетов для модели последовательности к последовательности:

  • Входы - это длинные последовательности, которые не помещаются в одну партию
  • Последовательности сильно различаются по длине
  • Каждый раз, когда генерируется партия, генератор партии смещает последовательности на определенную величину
  • Если какая-либо отдельная последовательность достигает конца, генератор партии заменяет эту конкретную часть партии новой последовательностью
  • Я хочу, чтобы состояние из предыдущей партии сохранялось в следующей партии; в то время как все новые добавленные последовательности заменяются нулевым состоянием

Что я пытаюсь сделать

В моей модели есть слой LSTM, который возвращает последовательность и конечные состояния:

lstm_layer = tf.keras.layers.CuDNNLSTM(state_size, return_sequences=True, return_state=True, stateful=True)
lstm,state_h,state_c = lstm_layer(concat_inputs)

Во время обучения мой batch_generator принимает в качестве входных данных предыдущее состояние и обновляет его. Затем, перед началом обучения, я сбрасываю состояние в слое lstm, используя недавно обновленные состояния:

(batch,states) = batch_generator.next_batch(last_states,seq_length)
lstm_layer.reset_states(states)

Затем я тренирую партию, используя train_on_batch:

loss = model.train_on_batch(batch[0],y=batch[1])

проблема

Я не смог найти способ извлечения состояний lstm (state_h,state_c) из модели в конце каждой партии.

Мой обходной путь

В настоящее время я использую обходной путь с небольшим запахом кода:

model._make_train_function()
model.train_function.outputs += [state_h,state_c]
x,y,sample_weights = model._standardize_user_data(inputs, targets, None, None)
outputs = model.train_function(x + y + sample_weights)
loss = outputs[0] #loss = model.train_on_batch(inputs,y=targets)
last_states = (outputs[2],outputs[3])

Есть ли лучший способ сделать это, не связанный с написанием моего собственного train_on_batch?

...