Pytorch: установить индексы в определенном тензорном измерении (аналогично torch.index_select) - PullRequest
0 голосов
/ 31 октября 2018

Я пытаюсь получить и установить индексы в определенном тензорном измерении без изменения формы, если это возможно. Мне удалось найти функцию torch.index_select, которая делает то, что я хочу при получении значений, но я еще не нашел аналогичной функции установки. Существует ли один?

Для контекста у меня есть тензор и набор индексов

class_energy = torch.rand(3, 10, 32, 32)
class_logits = torch.empty_like(class_energy)
idxs = [2, 3, 5, 7]

Я хочу получить доступ к элементам с этими индексами в определенном измерении , чтобы я мог выполнить log_softmax.

Если бы я знал dim a-priori, то я мог бы просто использовать причудливый синтаксис __getitem__ / __setitem__: например, если dim=1, то class_energy[:, idxs]. Аналогично, если dim=2 -> class_energy[:, :, idxs], dim=0 -> class_energy[idxs] и т. Д. *

В случае, когда dim=1, я по сути хочу это:

class_logits[:, idxs] = F.log_softmax(class_energy[:, idxs], dim=1)

К сожалению, я не знаю значение dim раньше времени. Конечно, я могу заранее составить модный индекс с помощью:

fancy_index = tuple([slice(None)] * dim + [idxs])
class_logits[fancy_index] = F.log_softmax(class_energy[fancy_index], dim=dim)

Однако мне интересно, есть ли лучший способ сделать это. Для случая __getitem__ я точно знаю, что есть. Следующий код с использованием torch.index_select эквивалентен

fancy_index = tuple([slice(None)] * dim + [idxs])
index = torch.LongTensor(idxs).to(class_energy.device)
class_logits[fancy_index] = F.log_softmax(torch.index_select(class_energy, dim=dim, index=index, dim=dim))

Мало того, что он функционально одинаков, index_select намного быстрее (я видел улучшение в 2 раза), чем при использовании необычного синтаксиса getitem.

Моя проблема в том, что я не могу показаться схожей функциональности для части __setitem__ кода. Было бы очень хорошо, если бы я мог избавиться от причудливого индекса все вместе. Я изучил Tensor.put_, torch.index_put и torch.select, но, похоже, ни один из них не обладает нужной мне функциональностью. Я что-то упускаю или модный индекс - единственный способ решить эту проблему в настоящее время?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...