Предположим, у меня есть тензор маски со всеми нулями, подобный этому:
mask = torch.zeros(5,3, dtype=torch.bool)
Теперь я хочу установить значение mask
на пересечении следующих индексов rows
и cols
True
:
rows = torch.tensor([0,2,4])
cols = torch.tensor([1,2])
Я хотел бы получить следующий результат:
tensor([[False, True, True ],
[False, False, False],
[False, True, True ],
[False, False, False],
[False, True, True ]])
Когда я пытаюсь использовать следующий код, я получаю сообщение об ошибке:
mask[rows, cols] = True
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [3], [2]
Как я могу сделать это эффективно в PyTorch?