Безопасно проверьте, удовлетворяет ли tf.Tensor какому-либо условию - PullRequest
2 голосов
/ 16 июня 2020

Если я хочу проверить, что все элементы в tnsr = tf.constant(...) больше 3, я могу tf.reduce_all(tnsr > 3) получить скалярный логический тензор. Если я использую активное выполнение или @tf.function, я могу использовать это как обычно bool:

@tf.function
def foo(tnsr):
    if tf.reduce_all(tnsr > 3):
        ...

, но этот не работает с autograph=False . Что мне тогда делать?

Другие вещи, которые я пробовал:

  • tf.cond возвращает tf.Tensor даже для true_fn=lambda: True
  • bool(tf.reduce_all(tnsr > 3)) не работает по той же причине, что и выше
...