Как эффективно извлечь индексы максимальных значений в тензоре Факела? - PullRequest
0 голосов
/ 08 ноября 2018

Предположим, у нас есть тензор факела, например, следующей формы:

x = torch.rand(20, 1, 120, 120)

Теперь мне хотелось бы получить индексы максимальных значений каждой матрицы 120x120. Чтобы упростить задачу, я бы сначала x.squeeze() поработал с формой [20, 120, 120]. Затем я хотел бы получить тензор факела, который представляет собой список индексов с формой [20, 2].

Как я могу сделать это быстро?

1 Ответ

0 голосов
/ 08 ноября 2018

Если я правильно вас понял, вам нужны не значения, а индексы.К сожалению, нет готового решения.Существует функция argmax(), но я не могу понять, как заставить ее делать именно то, что вы хотите.

Итак, вот небольшой обходной путь, эффективность также должна быть в порядке, поскольку мы просто делим тензоры:

n = torch.tensor(4)
d = torch.tensor(4)
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
# since argmax() does only return the index of the flattened
# matrix block we have to calculate the indices by ourself 
# by using / and % (// would also work, but as we are dealing with
# type torch.long / works as well
indices = torch.cat(((m / d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
print(x)
print(indices)

n представляет ваше первое измерение, а d последние два измерения.Я беру меньшие числа здесь, чтобы показать результат.Но, конечно, это также будет работать для n=20 и d=120:

n = torch.tensor(20)
d = torch.tensor(120)
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
indices = torch.cat(((m / d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
#print(x)
print(indices)

Вот вывод для n=4 и d=4:

tensor([[[[0.3699, 0.3584, 0.4940, 0.8618],
          [0.6767, 0.7439, 0.5984, 0.5499],
          [0.8465, 0.7276, 0.3078, 0.3882],
          [0.1001, 0.0705, 0.2007, 0.4051]]],


        [[[0.7520, 0.4528, 0.0525, 0.9253],
          [0.6946, 0.0318, 0.5650, 0.7385],
          [0.0671, 0.6493, 0.3243, 0.2383],
          [0.6119, 0.7762, 0.9687, 0.0896]]],


        [[[0.3504, 0.7431, 0.8336, 0.0336],
          [0.8208, 0.9051, 0.1681, 0.8722],
          [0.5751, 0.7903, 0.0046, 0.1471],
          [0.4875, 0.1592, 0.2783, 0.6338]]],


        [[[0.9398, 0.7589, 0.6645, 0.8017],
          [0.9469, 0.2822, 0.9042, 0.2516],
          [0.2576, 0.3852, 0.7349, 0.2806],
          [0.7062, 0.1214, 0.0922, 0.1385]]]])
tensor([[0, 3],
        [3, 2],
        [1, 1],
        [1, 0]])

Я надеюсь, что этоэто то, что вы хотели получить!:)

Редактировать:

Вот немного измененный, который может быть минимально быстрее (не так много, я думаю :), но этонемного проще и красивее:

Вместо этого, как раньше:

m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)

Необходимое изменение уже сделано для значений argmax:

m = x.view(n, -1).argmax(1).view(-1, 1)
indices = torch.cat((m // d, m % d), dim=1)

Но, как уже упоминалосьв комментариях.Я не думаю, что из этого можно получить гораздо больше.

Одна вещь, которую вы могли бы сделать, если действительно важно для вас получить от нее последнее возможное повышение производительности, - реализовать эту функцию как расширение низкого уровня (как в C ++) для pytorch.

Это даст вам только одну функцию, которую вы можете вызвать для нее, и позволит избежать медленного кода Python.

https://pytorch.org/tutorials/advanced/cpp_extension.html

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...