Постоянно обновлять tf.cond в зависимости от значения bool - PullRequest
0 голосов
/ 31 января 2020

Я использую tf.cond для управления потоком графика Tensorflow. Я просмотрел документацию и смог успешно реализовать ветвление на tf.cond. Но меня беспокоит то, что во время загрузки графика проверяется значение переменной bool, и решение о ветвлении принимается на самом этапе инициализации. Любые дальнейшие изменения в bool не отслеживаются. Ниже приведено MWE, которое лучше описывает проблему:

def funa():
    return tf.constant(32)

def funb():
    return tf.constant(25)

foo = True
x = tf.cond(tf.convert_to_tensor(foo), lambda: funa(), lambda: funb())
for i in range(20):
    global foo
    if i > 10:
        foo = False
    print(sess.run(x))    

Это печатает только 32 с.

Я тоже пытался с eager_execution со следующим кодом:

tf.enable_eager_execution()
def funa():
    return tf.constant(32)

def funb():
    return tf.constant(21)

foo = True
x = tf.cond(tf.convert_to_tensor(foo), lambda: funa(), lambda: funb())
for i in range(20):
    if i > 10:
        foo = False
    print(x)

Все тот же результат.

Так что мой вопрос, как я могу написать код такой, чтобы один часть графика выбирается динамически, исходя из обновлений переменной bool (если возможно)? Спасибо. Я использую Tensorflow v1.14.

1 Ответ

1 голос
/ 31 января 2020

Вы можете сделать заполнитель для foo и указать его значение во время сеанса. Модифицированный код:

import tensorflow as tf

def funa():
    return tf.constant(32)

def funb():
    return tf.constant(25)

foo = True
foo_p = tf.placeholder(tf.bool)

sess = tf.Session()

x = tf.cond(foo_p, lambda: funa(), lambda: funb())
for i in range(20):
    if i > 10:
        foo = False
    print(sess.run(x, {foo_p:foo}))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...