Я хочу использовать tf.train.MonitoredTrainingSession()
для обучения модели, описанной в Keras. Эта модель является моделью с состоянием, поэтому я хочу сбрасывать состояния после каждой эпохи.
Одна проблема заключается в том, что, если я позвоню model.reset_states()
, он выдаст следующую ошибку.
RuntimeError: График завершен и не может быть изменен.
Если вместо tf.train.MonitoredTrainingSession()
используется tf.Session()
, эта ошибка не появляется.
Например, в следующем примере кода, даже если это не обучающий код, генерируется то же сообщение об ошибке.
#!/usr/bin/python
import tensorflow as tf
inputs1 = tf.reshape(tf.linspace(0.0, 100.0, 10), (1, 2, 5))
inputs2 = tf.reshape(tf.linspace(100.0, 0.0, 10), (1, 2, 5))
model = tf.keras.Sequential([
tf.keras.layers.LSTM(
5,
return_sequences=True, stateful=True)
])
outputs1 = model(inputs1)
outputs2 = model(inputs2)
with tf.train.MonitoredTrainingSession() as sess:
model.reset_states()
print (sess.run(outputs1))
model.reset_states()
print (sess.run(outputs2))
Я нашел два способа решения этой проблемы:
Для использования tf.get_current_graph()._unsafe_unfinalize()
перед сбросом статистики.
Для использования tf.Session()
вместо tf.train.MonitoedTrainingSession()
.
Но я думаю, что оба подхода не идеальны.
Не могли бы вы предложить, какое было бы лучшее решение в этом случае?