Генераторы являются многопоточными, поэтому график, используемый внутри генератора, будет работать в другом потоке, чем тот, который создал график. Таким образом, доступ к генератору форм модели приведет к доступу к другому графику. Простое (но плохое) решение состоит в том, чтобы заставить генератор работать в том же потоке, что и тот, который создал график, установив workers=0
.
model.fit_generator(gen_data(), len(X_train)//batch_size, 1, validation_data=None, workers=0))
Код отладки:
def gen_data():
print ("-->",tf.get_default_graph())
while True:
for i in range(1):
yield (np.random.randn(batch_size, num_steps, num_input),
np.random.randn(batch_size, num_steps, 8))
model = get_model()
print (tf.get_default_graph())
model.fit_generator(gen_data(), 8, 1)
print (tf.get_default_graph())
выход
<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
--><tensorflow.python.framework.ops.Graph object at 0x14388e5c0>
Epoch 1/1
8/8 [==============================] - 4s 465ms/step - loss: 1.0198 - acc: 0.1575
<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
Вы можете видеть, что графические объекты отличаются. Выполнение workers=0
заставит генератор работать однопоточным.
Использование
model.fit_generator(gen_data(), 8, 1, workers=0)
Результаты в
<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
--> <tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
Epoch 1/1
8/8 [==============================] - 4s 466ms/step - loss: 1.0373 - acc: 0.0975
<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
тот же однопоточный генератор, имеющий доступ к тому же графику.
Однако, чтобы включить многопоточный генератор, элегантным методом было бы сохранить граф в переменную в основном процессе создания графа и передать его генератору, который использует переданный граф в качестве графа по умолчанию.