Например, я хочу получить индексы элементов со значениями 0 и 2 в тензоре a
. Эти значения (0 и 2) хранятся в тензоре b
. Я разработал способ Pythoni c для этого (показан ниже), но я не думаю, что списочные вычисления оптимизированы для работы на графическом процессоре, или, возможно, есть более PyTorchy способ сделать это, чего я не знаю.
import torch
a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])
torch.tensor([x in b for x in a]).nonzero()
>>>> tensor([[0],
[2],
[5],
[6]])
Любые другие предложения или это приемлемый способ?