PyTorch - AttributeError: объект 'bool' не имеет атрибута 'sum' - PullRequest
0 голосов
/ 17 апреля 2020

Я использую модель глубокого обучения с использованием PyTorch и получаю следующую ошибку.

'correct+=(yhat==y_test).sum().int()'

AttributeError: у объекта 'bool' нет атрибута 'sum'

Ниже приведен большой фрагмент кода.

'' '

for x_test, y_test in validation_loader:
            model.eval()
            z = model(x_test)
            yhat = torch.max(z.data,1)
            correct+=(yhat==y_test).sum().int()
            accuracy = correct / n_test
            accuracy_list.append(accuracy)

'' '

1 Ответ

0 голосов
/ 17 апреля 2020

Я могу ошибаться, но я думаю, что в этой строке

yhat = torch.max(z.data,1)

вы пытаетесь получить torch.argmax(). Я предполагаю, что вы пытаетесь получить прогнозы в формате [0, 1, 0, 1, 1], а не максимальную вероятность вашей партии.

...