Давайте разберем ваш пример кода:
Первая используемая вами функция - map_fn
. Map_fn
разделит тензор по первому измерению и передаст эти отдельные тензоры в его внутреннюю функцию. Эта функция не доставляет вам никаких проблем. Далее tf.cond
. tf.cond
ожидает скалярное значение в своем предикате. Чтобы разбить:
tensor = tf.constant([[0, 1, 0], [0.5, 0.5, 0.3]], tf.float32)
cond_val1 = tf.math.not_equal(tensor, 0.0)
print(cond_val1.shape) # Shape (2, 3)
cond_val1
в приведенном выше примере явно тензор. Вам придется использовать tf.reduce_all
или tf.reduce_any
, чтобы преобразовать его в скаляр. Тогда вы получите скаляр, необходимый для tf.cond
. Например:
cond_val2 = tf.reduce_all(tf.math.not_equal(tensor, 0.0))
print(cond_val2.shape) # Shape ()
Теперь это заставит tf.cond
работать. Но есть еще одна проблема, которую вы получаете. Вы уже потеряли способность обрабатывать тензор поэлементно.
Во-вторых, с помощью map_fn
вы пропускаете весь тензорный разбиение в первом измерении, которое в вашем случае будет [0, 1, 0]
и [0.5, 0.5, 0.3]
. Но проблема в том, что у tf.cond
true_fn
и false_fn
нет возможности обрабатывать его поэлементно.
Надеюсь, вы узнали о множестве проблем, которые присутствуют в вашем коде.