pytorch как выбрать каналы по маске? - PullRequest
0 голосов
/ 29 мая 2019

Я хочу знать, как выбрать каналы по маске в Pytorch. [канал1 канал2 канал3 канал4] x [1,0,0,1] -> [канал1, канал4] Я попытался torch.masked_select (), и это не сработало.

если вход имеет форму, подобную [B,C,H,W], форма вывода должна быть [B,masked_C,H,W],

import torch
from torch import nn
input = torch.randn((1,5,3,3))
pool = nn.AdaptiveAvgPool2d(1)
w = torch.sigmoid(pool(input)).view(1,-1)
mask = torch.gt(w,0.5)
print(input)
print(w)
print(mask)

вывод выглядит следующим образом:

tensor([[[[ 0.9129, -0.9763,  1.4460],
          [ 0.3608,  0.5561, -1.4612],
          [ 1.4953, -1.2474,  0.4069]],

         [[-0.9121,  0.1261,  0.4661],
          [-1.1624, -1.0266, -1.5419],
          [ 1.0644,  1.0039, -0.4022]],

         [[-1.8454, -0.2150,  2.3703],
          [ 0.5224,  0.3366,  1.7545],
          [-0.4624,  1.2639,  1.8032]],

         [[-1.1558, -1.9985, -1.1336],
          [-0.4400, -0.2092,  0.0677],
          [-0.4172, -0.3614, -1.3193]],

         [[-0.9441, -0.2944,  0.3381],
          [ 1.6562, -0.5623,  0.0599],
          [ 0.7229,  0.0472, -0.5122]]]])
tensor([[0.5414, 0.4341, 0.6489, 0.3156, 0.5142]])
tensor([[1, 0, 1, 0, 1]], dtype=torch.uint8)

результат, который я хочу получить, выглядит следующим образом:

tensor([[[[ 0.9129, -0.9763,  1.4460],
          [ 0.3608,  0.5561, -1.4612],
          [ 1.4953, -1.2474,  0.4069]],

         [[-1.8454, -0.2150,  2.3703],
          [ 0.5224,  0.3366,  1.7545],
          [-0.4624,  1.2639,  1.8032]],

         [[-0.9441, -0.2944,  0.3381],
          [ 1.6562, -0.5623,  0.0599],
          [ 0.7229,  0.0472, -0.5122]]]])

1 Ответ

0 голосов
/ 29 мая 2019

Я считаю, что вы можете просто сделать:

input[mask]

Btw. Вам не нужно звонить sigmoid, а затем .gt(0.5). Вы можете напрямую сделать .gt(0.0) без вызова сигмовидной кишки.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...