назначение элемента pytorch с использованием тензора многомерного индекса - PullRequest
0 голосов
/ 05 апреля 2020

У меня есть несколько тензоров 3-го порядка, все с одинаковыми первыми измерениями batch_size, seq_length:

  1. " output " со случайными значениями, последнее измерение является словарем измерение, которое я хотел бы проиндексировать.

  2. " indexer ", LongTensor с такими же значениями яркости, что и " output ", за исключением последнего измерения не словарный запас, а размер базы знаний, содержащей индексы токенов. Я бы хотел, чтобы индексы в диапазоне от 0 ... vocab_size индексировали первый Tensor.

  3. " values ​​", такие же тусклые значения, что и " indexer ». Содержит случайные значения; Я хочу добавить значение, содержащееся здесь, в pos " values ​​[x, y, z] ", чтобы добавить его к индексу " output ", заданному output [x, y, indexer [x, y, z]].

Как мне избежать этого для l oop с помощью методов pytorch?

batch, sequence, kb = values.shape
for x in range(batch):
    for y in range(sequence):
        for z in range(kb):
            output[x,y,indexer[x,y,z]] += values[x,y,z]

Что я пробовал :

In [52]: import torch
In [53]: output = torch.rand(3,12,40000) #batch x seq_len x vocab_size
In [54]: indexer = torch.randint(40000, (3,12,52)) #batch x seq_len, knowledgebase_size
In [55]: values = torch.rand(3,12,52) #batch x seq_len, knowledgebase_size
In [56]: x[indexer] += values                                                          
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-56-af06e3bd5b23> in <module>
----> 1 x[indexer] += values

IndexError: index 28 is out of bounds for dimension 0 with size 3
In [57]: x[:,:,indexer].shape
Out[57]: torch.Size([3, 12, 3, 12, 52])

Я хочу сделать это без использования циклов for и желательно на месте.

Не удалось найти точное соответствие для этого вопроса или указать c объяснение там , Поскольку кажется, что индексирование работает одинаково для всех факелов и numpy, я думаю, что этот вопрос применим и там (смотрите в документации по np):

Есть идеи?

...