Pytorch RuntimeError: неверный индекс в сборе - PullRequest
0 голосов
/ 26 января 2019

Я новичок в Pytorch, и я сталкиваюсь с этой ошибкой:

x.gather (1, c)

RuntimeError: Неверный индекс при сборе в / pytorch / aten /src / TH / generic / THTensorEvenMoreMath.cpp: 457

Вот некоторые сведения о тензорах:

print(x.size())
print(c.size())
print(type(x))
print(type(c))

torch.Size([128, 2])
torch.Size([128, 1])
<class 'torch.Tensor'>
<class 'torch.Tensor'>

x заполнено значениями с плавающей запятой и c целыми числами, может ли этобыть проблемой?

1 Ответ

0 голосов
/ 27 января 2019

Это просто означает, что ваш тензор индекса c имеет недопустимые индексы.Например, следующий тензор индекса действителен:

        x = torch.tensor([
        [5, 9, 1],
        [3, 2, 8],
        [7, 4, 0]
    ])
    c = torch.tensor([
        [0, 0, 0],
        [1, 2, 0],
        [2, 2, 1]
    ])
    x.gather(1, c)
>>>tensor([[5, 5, 5],
        [2, 8, 3],
        [0, 0, 4]])

Однако следующий тензор индекса недействителен:

c = torch.tensor([
    [0, 0, 0],
    [1, 2, 0],
    [2, 2, 3]
])

И он дает исключение, которое вы упомянули

RuntimeError: Неверный индекс в сборе

...