Поэтапное индексирование элементов в PyTorch для C ++ - PullRequest
0 голосов
/ 27 ноября 2018

Я использую интерфейс C ++ для PyTorch и борюсь с относительно простой проблемой индексации.

У меня есть тензор 8 от 6, такой как приведенный ниже:

[ Variable[CUDAFloatType]{8,6} ] 
                 0           1           2           3           4           5
0       1.7107e-14  4.0448e-17  4.9708e-06  1.1664e-08  9.9999e-01  2.1857e-20
1       1.8288e-14  5.9356e-17  5.3042e-06  1.2369e-08  9.9999e-01  2.4799e-20
2       2.6828e-04  9.0390e-18  1.7517e-02  1.0529e-03  9.8116e-01  6.7854e-26
3       5.7521e-10  3.1037e-11  1.5021e-03  1.2304e-06  9.9850e-01  1.4888e-17
4       1.7811e-13  1.8383e-15  1.6733e-05  3.8466e-08  9.9998e-01  5.2815e-20
5       9.6191e-06  2.6217e-23  3.1345e-02  2.3024e-04  9.6842e-01  2.9435e-34
6       2.2653e-04  8.4642e-18  1.6085e-02  9.7405e-04  9.8271e-01  6.3059e-26
7       3.8951e-14  2.9903e-16  8.3518e-06  1.7974e-08  9.9999e-01  3.6993e-20

У меня есть другой Тензор с только 8 элементами, такими как:

[ Variable[CUDALongType]{8} ] 
 0
 3
 4
 4
 4
 4
 4
 4

Я хотел бы проиндексировать строки моего первого тензора, используя второй для получения:

        0           
0       1.7107e-14  
1       1.2369e-08
2       9.8116e-01  
3       9.9850e-01  
4       9.9998e-01
5       9.6842e-01  
6       9.8271e-01  
7       9.9999e-01

Я пробовал несколько разных подходов, включая index_select, но, похоже, он выдает выходные данные, которые имеют те же размеры, что и входные (8x6).

В Python, я думаю, я мог бы индексировать с помощью встроенной индексации Python, как обсуждалось здесь: https://github.com/pytorch/pytorch/issues/1080

К сожалению, в C ++ я могу индексировать только тензор со скаляром (нульмерный тензор), поэтому я могуне думаю, что такой подход мне подходит здесь.

Как мне достичь желаемого результата, не прибегая к петлям?

1 Ответ

0 голосов
/ 29 ноября 2018

Оказывается, вы можете сделать это несколькими разными способами.Один с gather и один с index.Из обсуждений PyTorch , где я задал тот же вопрос:

Использование torch::gather

auto x = torch::randn({8, 6});
int64_t idx_data[8] = { 0, 3, 4, 4, 4, 4, 4, 4 };
auto idx = x.type().toScalarType(torch::kLong).tensorFromBlob(idx_data, 8);
auto result = x.gather(1, idx.unsqueeze(1));

Использование C ++ специфично torch::index

auto x = torch::randn({8, 6});
int64_t idx_data[8] = { 0, 3, 4, 4, 4, 4, 4, 4 };
auto idx = x.type().toScalarType(torch::kLong).tensorFromBlob(idx_data, 8);
auto rows = torch::arange(0, x.size(0), torch::kLong);
auto result = x.index({rows, idx});
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...