Как работает параметр 'dim' в torch.unique ()? - PullRequest
0 голосов
/ 19 января 2019

Я пытаюсь извлечь уникальные значения в каждой строке матрицы и вернуть их в одну и ту же матрицу (с повторяющимися значениями, скажем, 0). Например, я хотел бы преобразовать

torch.Tensor(([1, 2, 3, 4, 3, 3, 4],
              [1, 6, 3, 5, 3, 5, 4]])

до

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 0, 0, 4]])

или

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 4, 0, 0]])

Т.е. порядок в строках не имеет значения.Я пытался использовать pytorch.unique(), и в документации упоминается, что измерение для получения уникальных значений может быть указано с помощью параметра dim.Тем не менее, это не похоже на работу в этом случае.

Я пробовал:

output= torch.unique(torch.Tensor([[4,2,52,2,2],[5,2,6,6,5]]), dim = 1)

output

Что дает

tensor([[ 2.,  2.,  2.,  4., 52.],
        [ 2.,  5.,  6.,  5.,  6.]])

У кого-нибудь есть конкретное исправление для этого?Если возможно, я стараюсь избегать петель.

1 Ответ

0 голосов
/ 25 января 2019

Следует признать, что функция unique может иногда приводить в замешательство, не приводя надлежащих примеров и объяснений.

Параметр dim указывает, к какому измерению тензор матрицы вы хотите применить.

Например, в двумерной матрице dim=0 позволит операции выполнять вертикально, где dim=1 означает горизонтально.

В качестве примера рассмотрим матрицу 4x4 с dim=1.Как вы можете видеть из моего кода ниже, операция unique применяется строка за строкой.

Вы заметили двойное вхождение числа 11 в первом и последнем ряду.Numpy и Torch делают это, чтобы сохранить форму окончательной матрицы.

Однако, если вы не укажете какое-либо измерение, torch автоматически сгладит вашу матрицу, а затем применит к ней unique, и вы получите одномерный массив, содержащий уникальные данные.

import torch

m = torch.Tensor([
    [11, 11, 12,11], 
    [13, 11, 12,11], 
    [16, 11, 12, 11],  
    [11, 11, 12, 11]
])

output, indices = torch.unique(m, sorted=True, return_inverse=True, dim=1)
print("Ori \n{}".format(m.numpy()))
print("Sorted \n{}".format(output.numpy()))
print("Indices \n{}".format(indices.numpy()))

# without specifying dimension
output, indices = torch.unique(m, sorted=True, return_inverse=True)
print("Sorted (no dim) \n{}".format(output.numpy()))

Результат (dim = 1)

Ori
[[11. 11. 12. 11.]
 [13. 11. 12. 11.]
 [16. 11. 12. 11.]
 [11. 11. 12. 11.]]
Sorted
[[11. 11. 12.]
 [11. 13. 12.]
 [11. 16. 12.]
 [11. 11. 12.]]
Indices
[1 0 2 0]

Результат (без измерений)

Sorted (no dim)
[11. 12. 13. 16.]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...