Вы можете просто нарезать их и передать в виде индексов, как в:
In [193]: idxs = torch.nonzero(a == 1)
In [194]: c = b[idxs[:, 0], idxs[:, 1]]
In [195]: c
Out[195]:
tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
[0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
[0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])
В качестве альтернативы, еще более простой и предпочтительный для меня подход - просто использовать torch.where()
, а затем непосредственно индексировать в тензор b
, как в:
In [196]: b[torch.where(a == 1)]
Out[196]:
tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
[0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
[0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])
Немного более подробного объяснения вышеуказанного подхода к использованию torch.where()
: он работает на основе концепции расширенного индексирования .То есть, когда мы индексируем в тензор, используя кортеж объектов последовательности, таких как кортеж тензоров, кортеж списков, кортеж кортежей и т. Д.
# some input tensor
In [207]: a
Out[207]:
tensor([[12., 1., 0., 0.],
[ 4., 9., 21., 1.],
[10., 2., 1., 0.]])
Для базового среза нам понадобится кортеж целого числаиндексы:
In [212]: a[(1, 2)]
Out[212]: tensor(21.)
Чтобы добиться того же с помощью расширенного индексирования, нам понадобится кортеж объектов последовательности:
# adv. indexing using a tuple of lists
In [213]: a[([1,], [2,])]
Out[213]: tensor([21.])
# adv. indexing using a tuple of tuples
In [215]: a[((1,), (2,))]
Out[215]: tensor([21.])
# adv. indexing using a tuple of tensors
In [214]: a[(torch.tensor([1,]), torch.tensor([2,]))]
Out[214]: tensor([21.])
И размер возвращаемого тензора всегда будет на единицу меньшечем размерность входного тензора.