PyTorch - лучший способ вернуть оригинальный тензорный порядок после torch. - PullRequest
0 голосов
/ 01 сентября 2018

Я хочу вернуть исходный тензорный порядок после операции torch.sort и некоторых других модификаций отсортированного тензора, чтобы тензор больше не сортировался. Лучше объяснить это на примере:

x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
# ordered is [20., 30., 40.]
# indices is [2, 0, 1]
ordered = torch.tanh(ordered) # it doesn't matter what operation is
final = original_order(ordered, indices) 
# final must be equal to torch.tanh(x)

Я реализовал функцию следующим образом:

def original_order(ordered, indices):
    z = torch.empty_like(ordered)
    for i in range(ordered.size(0)):
        z[indices[i]] = ordered[i]
    return z

Есть ли лучший способ сделать это? В частности, можно ли избежать цикла и более эффективно вычислить операцию?

В моем случае у меня есть тензор размером torch.Size([B, N]), и я сортирую каждую из B строк отдельно с помощью одного вызова torch.sort. Итак, мне нужно вызвать original_order B раз с другим циклом.

Есть еще идеи, идеи?

РЕДАКТИРОВАТЬ 1 - Избавиться от внутренней петли

Я решил часть проблемы, просто проиндексировав z с помощью индексов следующим образом:

def original_order(ordered, indices):
    z = torch.empty_like(ordered)
    z[indices] = ordered
    return z

Теперь я просто должен понять, как избежать внешнего цикла в B измерении.

РЕДАКТИРОВАТЬ 2 - Избавиться от внешней петли

def original_order(ordered, indices, batch_size):
    # produce a vector to shift indices by lenght of the vector 
    # times the batch position
    add = torch.linspace(0, batch_size-1, batch_size) * indices.size(1)


    indices = indices + add.long().view(-1,1)

    # reduce tensor to single dimension. 
    # Now the indices take in consideration the new length
    long_ordered = ordered.view(-1)
    long_indices = indices.view(-1)

    # we are in the previous case with one dimensional vector
    z = torch.zeros_like(long_ordered).float()
    z[long_indices] = long_ordered

    # reshape to get back to the correct dimension
    return z.view(batch_size, -1)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...