Я хочу вернуть исходный тензорный порядок после операции torch.sort
и некоторых других модификаций отсортированного тензора, чтобы тензор больше не сортировался. Лучше объяснить это на примере:
x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
# ordered is [20., 30., 40.]
# indices is [2, 0, 1]
ordered = torch.tanh(ordered) # it doesn't matter what operation is
final = original_order(ordered, indices)
# final must be equal to torch.tanh(x)
Я реализовал функцию следующим образом:
def original_order(ordered, indices):
z = torch.empty_like(ordered)
for i in range(ordered.size(0)):
z[indices[i]] = ordered[i]
return z
Есть ли лучший способ сделать это? В частности, можно ли избежать цикла и более эффективно вычислить операцию?
В моем случае у меня есть тензор размером torch.Size([B, N])
, и я сортирую каждую из B
строк отдельно с помощью одного вызова torch.sort
. Итак, мне нужно вызвать original_order
B
раз с другим циклом.
Есть еще идеи, идеи?
РЕДАКТИРОВАТЬ 1 - Избавиться от внутренней петли
Я решил часть проблемы, просто проиндексировав z с помощью индексов следующим образом:
def original_order(ordered, indices):
z = torch.empty_like(ordered)
z[indices] = ordered
return z
Теперь я просто должен понять, как избежать внешнего цикла в B
измерении.
РЕДАКТИРОВАТЬ 2 - Избавиться от внешней петли
def original_order(ordered, indices, batch_size):
# produce a vector to shift indices by lenght of the vector
# times the batch position
add = torch.linspace(0, batch_size-1, batch_size) * indices.size(1)
indices = indices + add.long().view(-1,1)
# reduce tensor to single dimension.
# Now the indices take in consideration the new length
long_ordered = ordered.view(-1)
long_indices = indices.view(-1)
# we are in the previous case with one dimensional vector
z = torch.zeros_like(long_ordered).float()
z[long_indices] = long_ordered
# reshape to get back to the correct dimension
return z.view(batch_size, -1)