Tensorflow.py Защищенное подразделение - PullRequest
0 голосов
/ 03 февраля 2019

Я пытаюсь реализовать своего рода защищенное разделение, используя Tensorflow.where, но почему-то кажется, что оно пропускает условие, установленное в операторе where.

Основная идея заключается в том, что при делении x/y, если y == 0., то результат деления будет x вместо броска и ошибки.

Мой код выглядит следующим образом:

def Pdivide(x,y):
    result = tf.where(y == 0., x, x/y) 
    return result

Но каким-то образом это условие пропускается:

>>> a = tf.Variable([1.7, 0.2, 0., 1.1, 0.9, 0.3, 23., -1.])
>>> b = tf.Variable([0., 0., 0., 1., 1., 0., 1., 1.])

>>>Pdivide(a,b)

>>>(inf, inf, nan, 1.1, 0.9, inf, 23, -1)

Предполагаемый вывод:

>>>(1.7, 0.2, 0., 1.1, 0.9, 0.3, 23, -1)

PS: использование eager выполнения.

1 Ответ

0 голосов
/ 03 февраля 2019

Хорошо, так что ответ довольно прост.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...