Вот решение, которое вы можете найти
ids = ids.repeat(1, 255).view(-1, 1, 255)
Пример, приведенный ниже:
x = torch.arange(24).view(4, 3, 2)
"""
tensor([[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]])
"""
ids = torch.randint(0, 3, size=(4, 1))
"""
tensor([[0],
[2],
[0],
[2]])
"""
idx = ids.repeat(1, 2).view(4, 1, 2)
"""
tensor([[[0, 0]],
[[2, 2]],
[[0, 0]],
[[2, 2]]])
"""
torch.gather(x, 1, idx)
"""
tensor([[[ 0, 1]],
[[10, 11]],
[[12, 13]],
[[22, 23]]])
"""