TLDR;
Мы можем создать функцию, которая маскирует все строки, кроме верхних k
элементов, следующим образом:
def mask_all_but_top_k(X, k):
n = X.shape[1]
top_k_indices = tf.math.top_k(X, k).indices
mask = tf.reduce_sum(tf.one_hot(top_k_indices, n), axis=1)
return mask * X
К сожалению, tf.map.top_k
не позволяет нам указать размер, но мы, конечно, можем повторить этот столбец, сначала транспонировав X
, а затем транспонировав результат с tf.transpose()
Объяснение
Мы можем достичь этого, создав маску из единиц и нулей, а затем умножив их поэлементно.
Так, например, рассмотрим случай, когда n=4, k=2
и у нас есть следующая матрица:
array([[0.67757607, 0.74070597, 0.89508283, 0.11858773],
[0.7661159 , 0.8737055 , 0.73599136, 0.1552105 ],
[0.7093129 , 0.44203556, 0.48861897, 0.83231044],
[0.24682868, 0.36648738, 0.92984104, 0.9881872 ]], dtype=float32)
тогда мы можем использовать функцию tf.math.top_k
, чтобы получить индексы двух верхних значений в каждой строке матрицы:
top_k_indices = tf.math.top_k(X, 2).indices
Теперь мы используем небольшую хитрость, чтобы сначала one_hot
закодировать их:
tf.one_hot(top_k_indices, 4)
array([[[0., 0., 1., 0.],
[0., 1., 0., 0.]],
[[0., 1., 0., 0.],
[1., 0., 0., 0.]],
[[0., 0., 0., 1.],
[1., 0., 0., 0.]],
[[0., 0., 0., 1.],
[0., 0., 1., 0.]]], dtype=float32)>
затем reduce_sum
через второе и последнее измерение, чтобы создать нашу маску:
tf.reduce_sum(tf.one_hot(top_k_indices, 4), axis=1)
array([[0., 1., 1., 0.],
[1., 1., 0., 0.],
[1., 0., 0., 1.],
[0., 0., 1., 1.]], dtype=float32)>
теперь мы можем просто выполнить умножение Адамара (поэлементно), чтобы получить желаемый результат:
array([[0. , 0.74070597, 0.89508283, 0. ],
[0.7661159 , 0.8737055 , 0. , 0. ],
[0.7093129 , 0. , 0. , 0.83231044],
[0. , 0. , 0.92984104, 0.9881872 ]], dtype=float32)>
Собрав все это вместе, мы можем создать функцию, которая маскирует все строки, кроме верхних k
элементов, следующим образом:
def mask_all_but_top_k(X, k):
n = X.shape[1]
top_k_indices = tf.math.top_k(X, k).indices
mask = tf.reduce_sum(tf.one_hot(top_k_indices, n), axis=1)
return mask * X