попробуй torch.cat([(t == i).nonzero() for i in elements_to_compare])
>>> import torch
>>> t = torch.empty((15,4)).random_(0, 999)
>>> t
tensor([[429., 833., 393., 828.],
[555., 893., 846., 909.],
[ 11., 861., 586., 222.],
[232., 92., 576., 452.],
[171., 341., 851., 953.],
[ 94., 46., 130., 413.],
[243., 251., 545., 331.],
[620., 29., 194., 176.],
[303., 905., 771., 149.],
[482., 225., 7., 315.],
[ 44., 547., 206., 299.],
[695., 7., 645., 385.],
[225., 898., 677., 693.],
[746., 21., 505., 875.],
[591., 254., 84., 888.]])
>>> torch.cat([(t == i).nonzero() for i in [7,385]])
tensor([[ 9, 2],
[11, 1],
[11, 3]])
>>> torch.cat([(t == i).nonzero()[:,1] for i in [7,385]])
tensor([2, 1, 3])
Numpy:
>>> np.nonzero(np.isin(t, [7,385]))
(array([ 9, 11, 11], dtype=int64), array([2, 1, 3], dtype=int64))
>>> np.nonzero(np.isin(t, [7,385]))[1]
array([2, 1, 3], dtype=int64)