У меня есть 3-мерный массив NumPy, например:
x = [[[0.3, 0.2, 0.5],
[0.1, 0.2, 0.7],
[0.2, 0.2, 0.6]]]
Массив индексов также 3-мерный, как:
indices = [[[0],
[1],
[2]]]
Я ожидаю, что результат будет:
output= [[[0.3],
[0.2],
[0.6]]]
Я попробовал функции torch.index_select и torch.gather, но не смог найти правильный способ справиться с измерением.Спасибо за любую помощь!