Как я могу обрезать тензор на основе маски с помощью PyTorch? - PullRequest
0 голосов
/ 27 марта 2020

У меня есть тензор inp, который имеет размер: torch.Size([4, 122, 161]).

У меня также есть mask с размером: torch.Size([4, 122]).

Каждый элемент в моем mask выглядит примерно так:

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0', grad_fn=<SelectBackward>)

Так что я хочу обрезать inp, чтобы уменьшить его по измерению = 1, чтобы он существовал только там, где mask имеет 1. В показанном случае есть 23 1 с, поэтому я хочу, чтобы размер inp был: torch.Size([4, 23, 161])

1 Ответ

1 голос
/ 27 марта 2020

Я думаю, что расширенная индексация будет работать. (Я предполагаю, что каждая маска имеет равные 23 1 с)

inp_trimmed = inp[mask.type(torch.bool)].reshape(4,23,161)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...