Выбор / Фильтр по индексам с использованием pytorch - PullRequest
0 голосов
/ 03 января 2019

У меня есть 3-мерный массив NumPy, например:

 x = [[[0.3, 0.2, 0.5],
       [0.1, 0.2, 0.7],
       [0.2, 0.2, 0.6]]]

Массив индексов также 3-мерный, как:

indices = [[[0],
            [1],
            [2]]]

Я ожидаю, что результат будет:

 output= [[[0.3],
           [0.2],
           [0.6]]]

Я попробовал функции torch.index_select и torch.gather, но не смог найти правильный способ справиться с измерением.Спасибо за любую помощь!

Ответы [ 2 ]

0 голосов
/ 03 января 2019

Как насчет использования x.gather(dim=2, indices)?Это подходит для меня.

0 голосов
/ 03 января 2019

Я нашел ответ.Пожалуйста, дайте мне знать, если есть лучшее решение.

torch.cat([torch.index_select(a.view(1, -1), 1, i.view(1, -1)[0]) 
                                         for a, i in zip(x, indices)])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...