Я задал вопрос и реализовывал решение, когда обнаружил, что операция tf.math.count_nonzero
не имеет определенного градиента. Итак, я попробовал следующий раунд о методе:
eps = 1e-6
a = tf.ones((4, 4, 2, 2), tf.float32)
h = tf.linalg.svd(a, full_matrices=False, compute_uv=False)
cond = tf.less(h, eps)
h = tf.where(cond, tf.zeros(tf.shape(h)), h)
i = tf.reduce_sum(h, axis=-1)
j = h[:, :, 0]
rank_mat = tf.multiply(2., tf.ones((4, 4)))
cond = tf.not_equal(i, j)
rank_mat = tf.where(cond, rank_mat, tf.ones(tf.shape(rank_mat)))
cond = tf.equal(i, tf.zeros(shape=tf.shape(i), dtype=tf.float32))
rank_mat = tf.where(cond, tf.zeros(tf.shape(rank_mat)), rank_mat)
min_rank = tf.reduce_min(rank_mat)
Все еще та же ошибка сохраняется. Я частично понимаю, почему это происходит, но существует ли дифференцированный способ реализации этого? Спасибо.