Как получить индексы нескольких элементов в двумерном тензоре, дружественным к графическому процессору? - PullRequest
0 голосов
/ 05 февраля 2020

Этот вопрос похож на тот, на который уже дан ответ здесь , но этот вопрос не касается того, как получить индексы нескольких элементов.

У меня есть 2D-тензор points со многими строк и небольшое количество столбцов, и хотел бы получить тензор, содержащий индексы строк всех элементов в этом тензоре. Я знаю, какие элементы присутствуют в points заранее; Он содержит целочисленные элементы в диапазоне от 0 до 999, и я могу создать тензор, используя функцию диапазона, чтобы отразить множество возможных элементов. Элементы могут находиться в любом из столбцов.

Как я могу получить индексы строк, где каждый элемент появляется в моем тензоре, таким образом, чтобы избежать зацикливания или использования numpy, поэтому я могу сделать это быстро на GPU?

Я ищу что-то вроде (points == elements).nonzero()[:,1]

Спасибо!

Ответы [ 2 ]

0 голосов
/ 05 февраля 2020

Я не уверен, правильно ли я понимаю, что вы ищете, но если вам нужны индексы определенного значения, вы можете попробовать использовать where и разреженное представление результата.

Например, в приведенном ниже тензоре points значение 998 присутствует в индексах [0,0] и [2,0]. Чтобы получить эти показатели можно:

In [34]: points=torch.tensor([ [998,  6], [1, 3], [998, 999], [2, 3] ] )

In [35]: torch.where(points==998, points, torch.tensor(0)).to_sparse().indices()
Out[35]:
tensor([[0, 2],
        [0, 0]])
0 голосов
/ 05 февраля 2020

попробуй torch.cat([(t == i).nonzero() for i in elements_to_compare])

>>> import torch
>>> t = torch.empty((15,4)).random_(0, 999)
>>> t
tensor([[429., 833., 393., 828.],
        [555., 893., 846., 909.],
        [ 11., 861., 586., 222.],
        [232.,  92., 576., 452.],
        [171., 341., 851., 953.],
        [ 94.,  46., 130., 413.],
        [243., 251., 545., 331.],
        [620.,  29., 194., 176.],
        [303., 905., 771., 149.],
        [482., 225.,   7., 315.],
        [ 44., 547., 206., 299.],
        [695.,   7., 645., 385.],
        [225., 898., 677., 693.],
        [746.,  21., 505., 875.],
        [591., 254.,  84., 888.]])
>>> torch.cat([(t == i).nonzero() for i in [7,385]])
tensor([[ 9,  2],
        [11,  1],
        [11,  3]])

>>> torch.cat([(t == i).nonzero()[:,1] for i in [7,385]])
tensor([2, 1, 3])

Numpy:

>>> np.nonzero(np.isin(t, [7,385]))
(array([ 9, 11, 11], dtype=int64), array([2, 1, 3], dtype=int64))

>>> np.nonzero(np.isin(t, [7,385]))[1]
array([2, 1, 3], dtype=int64)

...