Вы можете сделать так:
import torch
n = 10
batch_size = 2
in_flags = torch.randint(n, (batch_size,), dtype=torch.long)
x = torch.zeros((batch_size, n), dtype=torch.float)
# this is how you can do this
x[torch.arange(batch_size), in_flags] = 1.0
print(in_flags)
print(x)
Выход:
tensor([8, 0])
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])