PyTorch переводит индекс в 1-0 - PullRequest
1 голос
/ 04 апреля 2020

Как быстро установить элементы в указанном списке индексов на 1, а другие на 0?

Например, у меня есть пул идентификаторов, например: torch.arange(10), для данного входного индекса tensor([1,5,7,9,2]) хочу возврат tensor([0,1,1,0,0,1,0,1,0,1])

1 Ответ

1 голос
/ 04 апреля 2020

Проще всего начать с zeros и заполнить ones, используя необычную индексацию, например:

import torch

tensor = torch.zeros(10)
tensor[[1, 5, 7, 9, 2]] = 1

Если ваши идентификаторы предопределены (например, torch.arange(10)), и вы хотите получить только те элементы, которые не zero, вы можете сделать это:

import torch

ids = torch.arange(10)

mask = torch.zeros_like(ids).bool() # it has to be bool
mask[[1, 5, 7, 9, 2]] = True

torch.masked_select(ids, mask)

Что даст вам:

tensor([1, 2, 5, 7, 9])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...