Как выбрать все значения многомерного тензора на основе маски с помощью PyTorch? - PullRequest
0 голосов
/ 28 марта 2020

У меня есть тензор pred, который имеет .size() [4, 53, 161]. У меня есть другой тензор mask, который имеет .size() [4, 53].

. mask - это только 0 и 1. Я хочу получить значения pred, где mask имеет значение 1. Вы заметите, что pred имеет еще одно измерение, чем mask. Как я могу выполнить sh это?

1 Ответ

1 голос
/ 28 марта 2020
mask = mask.unsqueeze(2)
new_pred = pred * mask

Это добавит дополнительное измерение. Это будет сейчас [4, 53, 1]. Остальные позаботятся о трансляции. (если вы делаете какую-то операцию)

Предположим, у вас есть тензор формы [4, 53, 164], и теперь вы хотите уменьшить его до [4, 53], тогда вы можете применять арифметические c операции, подобные этой new_pred.mean(2).

...