Как изменить значения 2-мерного тензора в определенных строках и столбцах - PullRequest
1 голос
/ 09 апреля 2020

Предположим, у меня есть тензор маски со всеми нулями, подобный этому:

mask = torch.zeros(5,3, dtype=torch.bool)

Теперь я хочу установить значение mask на пересечении следующих индексов rows и cols True:

rows = torch.tensor([0,2,4]) 
cols = torch.tensor([1,2])

Я хотел бы получить следующий результат:

tensor([[False, True,  True ],
        [False, False, False],
        [False, True,  True ],
        [False, False, False],
        [False, True,  True ]])

Когда я пытаюсь использовать следующий код, я получаю сообщение об ошибке:

mask[rows, cols] = True

IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [3], [2]

Как я могу сделать это эффективно в PyTorch?

1 Ответ

1 голос
/ 09 апреля 2020

Вам нужна правильная форма, для которой вы можете использовать torch.unsqueeze

mask = torch.zeros(5,3, dtype=torch.bool)
mask[rows, cols.unsqueeze(1)] = True
mask
tensor([[False,  True,  True],
        [False, False, False],
        [False,  True,  True],
        [False, False, False],
        [False,  True,  True]])

или torch.reshape

mask[rows, cols.reshape(-1,1)] = True
mask
tensor([[False,  True,  True],
        [False, False, False],
        [False,  True,  True],
        [False, False, False],
        [False,  True,  True]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...