Если ошибка, которую вы получаете, это:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: one_hot is only applicable to index tensor.
Может быть, вам просто нужно конвертировать в int64
:
import torch
# random Tensor with the shape you said
indices = torch.Tensor(1, 1, 128, 128, 128).random_(1, 24)
# indices.shape => torch.Size([1, 1, 128, 128, 128])
# indices.dtype => torch.float32
n = 24
one_hot = torch.nn.functional.one_hot(indices.to(torch.int64), n)
# one_hot.shape => torch.Size([1, 1, 128, 128, 128, 24])
# one_hot.dtype => torch.int64
Вы также можете использовать indices.long()
.