Tf.where не оценивает - PullRequest
       1

Tf.where не оценивает

0 голосов
/ 02 сентября 2018
sess = tf.InteractiveSession()
t = tf.expand_dims(tf.constant(list(range(9))), axis=1)
tf.where(t == 5).eval()

InvalidArgumentError (see above for traceback): WhereOp : Unhandled input dimensions: 0
     [[Node: Where_16 = Where[T=DT_BOOL, _device="/job:localhost/replica:0/task:0/device:CPU:0"](Where_16/condition)]]

Что здесь происходит? Соответствующий код в Numpy с np.where работает.

1 Ответ

0 голосов
/ 04 декабря 2018

В вашем примере вы оцениваете tf.where(False), так как оператор == не перегружен для тензоров. (Более подробная информация, например, здесь: Перегрузка оператора TensorFlow )

Попробуйте:

sess = tf.InteractiveSession()
t = tf.expand_dims(tf.constant(list(range(9))), axis=1)
tf.where(tf.equal(t, 5)).eval()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...