Tensorflow: индексирование без явных номеров строк при одинаковом количестве строк - PullRequest
0 голосов
/ 24 января 2020

У меня есть двумерный массив индексов, каждая строка содержит массив элементов, которые необходимо переставить в каждой строке в матрице mask. Следует отметить, что количество строк как в mask, так и в indices равно.

mask = tf.random.uniform((5, 5)) > 0.5
indices = tf.random.uniform((5, 3), minval=0, maxval=5, dtype=tf.int64)
mask, indices
(<tf.Tensor: shape=(5, 5), dtype=bool, numpy=
 array([[False, False,  True,  True,  True],
        [False, False, False, False,  True],
        [False,  True,  True, False, False],
        [False, False,  True, False,  True],
        [ True,  True, False, False, False]])>,
 <tf.Tensor: shape=(5, 3), dtype=int64, numpy=
 array([[2, 4, 3],
        [3, 1, 1],
        [0, 0, 4],
        [1, 1, 0],
        [0, 3, 0]])>)

Поэтому я ожидаю результат, такой что операция будет

[(mask[0])[indices[0]],      #  <--- mask[0][2,4,3] i.e. permute element [2,4,3] in mask[0]
 (mask[1])[indices[1]],
 ...
 (mask[-1])[indices[-1]]]
array([[ True,  True,  True],
       [False, False, False],
       [False, False, False],
       [False, False, False],
       [ True, False,  True]])>

Как мне сделать это эффективно?

В данный момент я добавляю индекс строки к каждому элементу в indices формы [5,3], чтобы сделать его [5,3,2]

# same as indices, with row numbers added to each element
# i.e. [2, 4, 3] becomes [[0, 2], [0, 4], [0, 3]]
#      [3, 1, 1] becomes [[1, 3], [1, 1], [1, 1]], etc.

print(indices_adjusted)
<tf.Tensor: shape=(5, 3, 2), dtype=int64, numpy=
array([[[0, 2],
        [0, 4],
        [0, 3]],

       [[1, 3],
        [1, 1],
        [1, 1]],

       [[2, 0],
        [2, 0],
        [2, 4]],

       [[3, 1],
        [3, 1],
        [3, 0]],

       [[4, 0],
        [4, 3],
        [4, 0]]])>

Так что я могу передать это непосредственно tf.gather_nd. Есть ли лучший способ сделать это, вместо того, чтобы ставить номера строк (возможно, с помощью некоторых магических c трансляций?)

Colab Plaground здесь

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...