Я пытаюсь манипулировать некоторыми данными в Python внутри пользовательской функции потерь в Tensorflow.keras
Рассмотрим следующий пример:
b = tf.constant([[0, 3, 1], [0, 5, 2]])
Я хотел бы стереть нулевой столбец или извлечь ненулевой столбец так, чтобы конечным результатом был тензор
[[3,1], [5,2]]
Я пытался с tf.where используя маску, но она не поддерживает форму, она просто возвращает одномерный тензор с ненулевыми значениями. Кроме того, мне нужно, чтобы это работало для произвольного числа строк, единственное исправленное число столбцов.