Получить индексы элементов в тензоре а, присутствующих в тензоре б - PullRequest
2 голосов
/ 29 марта 2020

Например, я хочу получить индексы элементов со значениями 0 и 2 в тензоре a. Эти значения (0 и 2) хранятся в тензоре b. Я разработал способ Pythoni c для этого (показан ниже), но я не думаю, что списочные вычисления оптимизированы для работы на графическом процессоре, или, возможно, есть более PyTorchy способ сделать это, чего я не знаю.

import torch
a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])
torch.tensor([x in b for x in a]).nonzero()

>>>> tensor([[0],
             [2],
             [5],
             [6]])

Любые другие предложения или это приемлемый способ?

1 Ответ

2 голосов
/ 30 марта 2020

Вот более эффективный способ сделать это (как предлагается в ссылке, размещенной jodag в комментариях ...):

(a[..., None] == b).any(-1).nonzero()
...