Индексирование тензоров Pytorch: как собирать строки по тензорам, содержащим индексы - PullRequest
0 голосов
/ 27 апреля 2019

У меня есть тензоры:

идентификаторы : форма (7000,1), содержащая индексы, такие как [[1], [0], [2], ...]

x : форма (7000, 3 , 255)

Идентификатор тензора кодирует индекс выделенного жирным шрифтом размера x, который следует выбрать. Я хочу собрать выбранные фрагменты в результирующем векторе:

результат: форма (7000,255)

Справочная информация:

У меня есть несколько баллов (форма = (7000,3)) для каждого из 3 элементов, и я хочу выбрать только тот, который набрал наибольшее количество баллов. Поэтому я использовал функцию

ids = torch.argmax(scores,1,True)

давая мне максимальные идентификаторы. Я уже пытался сделать это с помощью функции сбора:

result = x.gather(1,ids)

но это не сработало.

Ответы [ 2 ]

1 голос
/ 27 апреля 2019

Вот решение, которое вы можете найти

ids = ids.repeat(1, 255).view(-1, 1, 255)

Пример, приведенный ниже:

x = torch.arange(24).view(4, 3, 2) 
"""
tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15],
         [16, 17]],

        [[18, 19],
         [20, 21],
         [22, 23]]])
"""
ids = torch.randint(0, 3, size=(4, 1))
"""
tensor([[0],
        [2],
        [0],
        [2]])
"""
idx = ids.repeat(1, 2).view(4, 1, 2) 
"""
tensor([[[0, 0]],

        [[2, 2]],

        [[0, 0]],

        [[2, 2]]])
"""

torch.gather(x, 1, idx) 
"""
tensor([[[ 0,  1]],

        [[10, 11]],

        [[12, 13]],

        [[22, 23]]])
"""
0 голосов
/ 27 апреля 2019

на примере Дэвида Нга я нашел другой способ сделать это:

idx = ids.flatten() + torch.arange(0,4*3,3)

tensor([ 0,  5,  6, 11])



x.view(-1,2)[idx]

tensor([[ 0,  1],
        [10, 11],
        [12, 13],
        [22, 23]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...