Индексирование в PyTorch почти похоже на numpy.
a = torch.randn(2, 2, 3)
b = torch.eye(2, 2, dtype=torch.long)
c = torch.eye(2, 2, dtype=torch.long)
print(a)
print(a[b, c, :])
tensor([[[ 1.2471, 1.6571, -2.0504],
[-1.7502, 0.5747, -0.3451]],
[[-0.4389, 0.4482, 0.7294],
[-1.3051, 0.6606, -0.6960]]])
tensor([[[-1.3051, 0.6606, -0.6960],
[ 1.2471, 1.6571, -2.0504]],
[[ 1.2471, 1.6571, -2.0504],
[-1.3051, 0.6606, -0.6960]]])