Установка переменных Keras в генераторе - PullRequest
4 голосов
/ 12 апреля 2019

Я хочу установить свое скрытое состояние LSTM в генераторе. Однако набор состояния работает только вне генератора:

K.set_value(model.layers[0].states[0], np.random.randn(batch_size,num_outs)) # this works

def gen_data():
    x = np.zeros((batch_size, num_steps, num_input))
    y = np.zeros((batch_size, num_steps, num_output))
    while True:
        for i in range(batch_size):
            K.set_value(model.layers[0].states[0], np.random.randn(batch_size,num_outs)) # error
            x[i, :, :] = X_train[gen_data.current_idx]
            y[i, :, :] = Y_train[gen_data.current_idx]
            gen_data.current_idx += 1
        yield x, y
gen_data.current_idx = 0

Генератор вызывается в функции fit_generator:

model.fit_generator(gen_data(), len(X_train)//batch_size, 1, validation_data=None)

Это результат печати состояния:

print(model.layers[0].states[0])
<tf.Variable 'lstm/Variable:0' shape=(1, 2) dtype=float32>

Это ошибка, которая возникает в генераторе:

ValueError: Tensor("Placeholder_1:0", shape=(1, 2), dtype=float32) must be from the same graph as Tensor("lstm/Variable:0", shape=(), dtype=resource)

Что я делаю не так?

1 Ответ

1 голос
/ 12 апреля 2019

Генераторы являются многопоточными, поэтому график, используемый внутри генератора, будет работать в другом потоке, чем тот, который создал график. Таким образом, доступ к генератору форм модели приведет к доступу к другому графику. Простое (но плохое) решение состоит в том, чтобы заставить генератор работать в том же потоке, что и тот, который создал график, установив workers=0.

model.fit_generator(gen_data(), len(X_train)//batch_size, 1, validation_data=None, workers=0))

Код отладки:

def gen_data():
    print ("-->",tf.get_default_graph())
    while True:
        for i in range(1):
            yield (np.random.randn(batch_size, num_steps, num_input), 
            np.random.randn(batch_size, num_steps, 8))

model = get_model()
print (tf.get_default_graph())
model.fit_generator(gen_data(), 8, 1)
print (tf.get_default_graph())

выход

<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
--><tensorflow.python.framework.ops.Graph object at 0x14388e5c0>
Epoch 1/1 
8/8 [==============================] - 4s 465ms/step - loss: 1.0198 - acc: 0.1575
<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>

Вы можете видеть, что графические объекты отличаются. Выполнение workers=0 заставит генератор работать однопоточным.

Использование

model.fit_generator(gen_data(), 8, 1, workers=0)

Результаты в

<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
--> <tensorflow.python.framework.ops.Graph object at 0x1228a5e80>
Epoch 1/1
8/8 [==============================] - 4s 466ms/step - loss: 1.0373 - acc: 0.0975
<tensorflow.python.framework.ops.Graph object at 0x1228a5e80>

тот же однопоточный генератор, имеющий доступ к тому же графику.

Однако, чтобы включить многопоточный генератор, элегантным методом было бы сохранить граф в переменную в основном процессе создания графа и передать его генератору, который использует переданный граф в качестве графа по умолчанию.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...