Как извлечь ненулевые значения из тензора в керасе - PullRequest
1 голос
/ 30 апреля 2020

Я пытаюсь манипулировать некоторыми данными в Python внутри пользовательской функции потерь в Tensorflow.keras

Рассмотрим следующий пример:

b = tf.constant([[0, 3, 1], [0, 5, 2]])

Я хотел бы стереть нулевой столбец или извлечь ненулевой столбец так, чтобы конечным результатом был тензор

[[3,1], [5,2]]

Я пытался с tf.where используя маску, но она не поддерживает форму, она просто возвращает одномерный тензор с ненулевыми значениями. Кроме того, мне нужно, чтобы это работало для произвольного числа строк, единственное исправленное число столбцов.

1 Ответ

3 голосов
/ 30 апреля 2020

это выбирает все столбцы с суммой> 0:

tf.transpose(tf.gather_nd(tf.transpose(b), tf.where(tf.reduce_sum(b, axis=0)>0)))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...