Как получить топ-k элементов каждой строки в 2D-тензоре? - PullRequest
0 голосов
/ 10 марта 2020

Как элегантно получить топ-k элементов каждой строки в 2D-тензоре вместо использования for-l oop, как показано ниже?

import torch

elements = torch.rand(5,10)
topk_list = [2,3,1,2,0] # means top2 for 1st row, top3 for 2nd row, top1 for 3rd row,....
index_list = [] # record the topk index in elements

for i in range(5):
    index_list.append(elements[i].topk(topk_list[i]))

Ответы [ 2 ]

1 голос
/ 10 марта 2020

Если ваши k не слишком сильно различаются и вы хотите векторизовать свой код, вы можете сначала взять максимальную вершину k на строку, а затем собрать желаемые результаты.

# Code from OP
import torch

elements = torch.rand(5,10)
topk_list = [2,3,1,2,0] # means top2 for 1st row, top3 for 2nd row, top1 for 3rd row,....
index_list = [] # record the topk index in elements

for i in range(5):
    index_list.append(elements[i].topk(topk_list[i]))

# Print the result
print(index_list)

# Get topk for max_k
max_k = max(topk_list)
topk_vals, topk_inds = elements.topk(max_k, dim=-1)

# Select desired topk using mask
mask = torch.arange(max_k)[None, :] < torch.tensor(topk_list)[:, None]
vals, inds = topk_vals[mask], topk_inds[mask]
rows, _ = mask.nonzero().T
print("-" * 10)
print("rows", rows)
print("inds", inds)
print("vals", vals)

# Or split
vals_per_row = vals.split(topk_list)
inds_per_row = inds.split(topk_list)
print("-" * 10)
print("vals_per_row", vals_per_row)
print("inds_per_row", inds_per_row)

# Or zip (for loop but should be cheap)
index_list = zip(vals_per_row, inds_per_row)
print("-" * 10)
print("zipped results", list(index_list))

Это дает следующий вывод:

[torch.return_types.topk(
values=tensor([0.8148, 0.7443]),
indices=tensor([8, 4])), torch.return_types.topk(
values=tensor([0.7529, 0.7352, 0.6354]),
indices=tensor([8, 1, 9])), torch.return_types.topk(
values=tensor([0.8792]),
indices=tensor([7])), torch.return_types.topk(
values=tensor([0.9626, 0.8728]),
indices=tensor([6, 2])), torch.return_types.topk(
values=tensor([]),
indices=tensor([], dtype=torch.int64))]
----------
rows tensor([0, 0, 1, 1, 1, 2, 3, 3])
inds tensor([8, 4, 8, 1, 9, 7, 6, 2])
vals tensor([0.8148, 0.7443, 0.7529, 0.7352, 0.6354, 0.8792, 0.9626, 0.8728])
----------
vals_per_row (tensor([0.8148, 0.7443]), tensor([0.7529, 0.7352, 0.6354]), tensor([0.8792]), tensor([0.9626, 0.8728]), tensor([]))
inds_per_row (tensor([8, 4]), tensor([8, 1, 9]), tensor([7]), tensor([6, 2]), tensor([], dtype=torch.int64))
----------
zipped results [(tensor([0.8148, 0.7443]), tensor([8, 4])), (tensor([0.7529, 0.7352, 0.6354]), tensor([8, 1, 9])), (tensor([0.8792]), tensor([7])), (tensor([0.9626, 0.8728]), tensor([6, 2])), (tensor([]), tensor([], dtype=torch.int64))]
0 голосов
/ 10 марта 2020

Является ли что-то элегантным или нет, всегда обсуждается. Использование фиксированного диапазона в for l oop определенно может быть улучшено, вы можете, по крайней мере, использовать range(len(topk_list)), чтобы код можно было повторно использовать для различных списков topk.

Вы могли бы улучшить, используя:

for i, n in enumerate(topk_list): 
    index_list.append(elements[i].topk(n))

Или даже:

index_list = [ elements[i].topk(n) for i, n in enumerate(topk_list) ]

Но это всего лишь синтаксический сахар.

...