Keras model.reset_states () не работает с tf.train.MonitoredTrainingSession - PullRequest
1 голос
/ 09 июня 2019

Я хочу использовать tf.train.MonitoredTrainingSession() для обучения модели, описанной в Keras. Эта модель является моделью с состоянием, поэтому я хочу сбрасывать состояния после каждой эпохи.

Одна проблема заключается в том, что, если я позвоню model.reset_states(), он выдаст следующую ошибку.

RuntimeError: График завершен и не может быть изменен.

Если вместо tf.train.MonitoredTrainingSession() используется tf.Session(), эта ошибка не появляется.

Например, в следующем примере кода, даже если это не обучающий код, генерируется то же сообщение об ошибке.

#!/usr/bin/python

import tensorflow as tf


inputs1 = tf.reshape(tf.linspace(0.0, 100.0, 10), (1, 2, 5)) 
inputs2 = tf.reshape(tf.linspace(100.0, 0.0, 10), (1, 2, 5)) 

model = tf.keras.Sequential([
    tf.keras.layers.LSTM(
    5,  
    return_sequences=True, stateful=True)
])

outputs1 = model(inputs1)
outputs2 = model(inputs2)

with tf.train.MonitoredTrainingSession() as sess:
  model.reset_states()
  print (sess.run(outputs1))
  model.reset_states()
  print (sess.run(outputs2))

Я нашел два способа решения этой проблемы:

  1. Для использования tf.get_current_graph()._unsafe_unfinalize() перед сбросом статистики.

  2. Для использования tf.Session() вместо tf.train.MonitoedTrainingSession().

Но я думаю, что оба подхода не идеальны. Не могли бы вы предложить, какое было бы лучшее решение в этом случае?

...