Вычисление инверсии тензора с использованием tf.map_fn - PullRequest
1 голос
/ 02 октября 2019

Использование Tensorflow 1.4 Я хочу вычислить поэлементно обратное (x -> 1 / x) тензора, используя функции отображения. Если значение элемента в Tensor равно нулю, я хочу получить нулевой вывод.

В качестве примера для tensor: [[0, 1, 0], [0.5, 0.5, 0.3]], я хочу получить output: [[0,1,0], [2, 2, 3.333]]. Я знаю, что могу легко получить желаемый результат, используя tf.math.reciprocal_no_nan () в tf2.0 и tf.math.divide_no_nan() в tf 1.4, но мне интересно, почему следующий код не работает:

tensor = tf.constant([[0, 1, 0], [0.5, 0.5, 0.3]], tf.float32)
tensor_inverse = tf.map_fn(lambda x: tf.cond(tf.math.not_equal(x, 0.0), lambda x: 1/x, lambda: 0) , tensor)

Я получаю эту ошибку:

ValueError: Значение истинности массива с более чем одним элементом является неоднозначным. Используйте a.any () или a.all ()

1 Ответ

0 голосов
/ 02 октября 2019

Давайте разберем ваш пример кода:

Первая используемая вами функция - map_fn. Map_fn разделит тензор по первому измерению и передаст эти отдельные тензоры в его внутреннюю функцию. Эта функция не доставляет вам никаких проблем. Далее tf.cond. tf.cond ожидает скалярное значение в своем предикате. Чтобы разбить:

tensor = tf.constant([[0, 1, 0], [0.5, 0.5, 0.3]], tf.float32)
cond_val1 = tf.math.not_equal(tensor, 0.0)
print(cond_val1.shape)  # Shape (2, 3)

cond_val1 в приведенном выше примере явно тензор. Вам придется использовать tf.reduce_all или tf.reduce_any, чтобы преобразовать его в скаляр. Тогда вы получите скаляр, необходимый для tf.cond. Например:

cond_val2 = tf.reduce_all(tf.math.not_equal(tensor, 0.0))
print(cond_val2.shape)  # Shape ()

Теперь это заставит tf.cond работать. Но есть еще одна проблема, которую вы получаете. Вы уже потеряли способность обрабатывать тензор поэлементно.

Во-вторых, с помощью map_fn вы пропускаете весь тензорный разбиение в первом измерении, которое в вашем случае будет [0, 1, 0] и [0.5, 0.5, 0.3]. Но проблема в том, что у tf.cond true_fn и false_fn нет возможности обрабатывать его поэлементно.

Надеюсь, вы узнали о множестве проблем, которые присутствуют в вашем коде.

...