Итак, чтобы сделать выбор multi-index
, вы можете использовать функцию torch.gather , которая собирает значения вдоль оси, заданной параметром dim (второй параметр).
Пример 1:
t2 = torch.tensor([[0.1, 0.2, 0.3, 0.4],
[0.8, 1.8, 0.2, 0.3],
[0.5, 0.1, 0.2, 0.4]])
indexes1 = torch.tensor([[0, 2, 0, 2],
[0, 1, 1, 0],
[0, 0, 1, 2]])
t1 = torch.gather(t2, 0, indexes1) # dim is 0
print(t1)
выход:
tensor([[0.1000, 0.1000, 0.3000, 0.4000],
[0.1000, 1.8000, 0.2000, 0.4000],
[0.1000, 0.2000, 0.2000, 0.4000]])
Пример 2:
t2 = torch.tensor([[0.1, 0.2, 0.3, 0.4],
[0.8, 1.8, 0.2, 0.3],
[0.5, 0.1, 0.2, 0.4]])
indexes2 = torch.tensor([[0, 3, 2, 0],
[0, 1, 1, 3],
[0, 0, 3, 2]])
t1 = torch.gather(t2, 1, indexes2) # dim is 1
print(t1)
выход:
tensor([[0.1000, 0.4000, 0.3000, 0.1000],
[0.8000, 1.8000, 1.8000, 0.3000],
[0.5000, 0.5000, 0.4000, 0.2000]])
Чтобы узнать больше о функции torch.gather
, просто пройдите это ТАКОЕ обсуждение.
Вы также можете использовать torch.Tensor.scatter_
, чтобы сделать то же самое.
t1.scatter_(0, indexes, t2)
в основном говорит, что отправьте элементы тензора t2
по следующим индексам (указанным в indexes
тензор) в тензоре t1
, построчно (dim 0).
Пример:
t1 = torch.zeros((3, 4))
t2 = torch.tensor([[0.1, 0.2, 0.3, 0.4],
[0.8, 1.8, 0.2, 0.3],
[0.5, 0.1, 0.2, 0.4]])
indexes = torch.tensor([[1, 2, 0, 2],
[0, 1, 2, 1],
[2, 0, 1, 0]])
t1 = t1.scatter_(0, indexes, t2)
print(t1)
выход:
tensor([[0.8000, 0.1000, 0.3000, 0.4000],
[0.1000, 1.8000, 0.2000, 0.3000],
[0.5000, 0.2000, 0.2000, 0.4000]])
Подробнее об этом можно прочитать по здесь .