Если я хочу проверить, что все элементы в tnsr = tf.constant(...)
больше 3, я могу tf.reduce_all(tnsr > 3)
получить скалярный логический тензор. Если я использую активное выполнение или @tf.function
, я могу использовать это как обычно bool
:
@tf.function
def foo(tnsr):
if tf.reduce_all(tnsr > 3):
...
, но этот не работает с autograph=False
. Что мне тогда делать?
Другие вещи, которые я пробовал:
tf.cond
возвращает tf.Tensor
даже для true_fn=lambda: True
bool(tf.reduce_all(tnsr > 3))
не работает по той же причине, что и выше