Pytorch: AttributeError: у объекта 'function' нет атрибута 'copy' - PullRequest
0 голосов
/ 16 апреля 2020

Я пытаюсь загрузить модель state_dict Я тренировался на Google Colab GPU, вот мой код для загрузки модели:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = models.resnet50()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, n_classes)
model.load_state_dict(copy.deepcopy(torch.load("./models/model.pth",device)))
model = model.to(device)
model.eval()

Вот ошибка:

state_dict = state_dict.copy ()

AttributeError: у объекта 'function' нет атрибута 'copy'

Pytorch:

>>> import torch
>>> print (torch.__version__)
1.4.0
>>> import torchvision
>>> print (torchvision.__version__)
0.5.0

Пожалуйста, помогите мне искали везде безрезультатно

[полная информация об ошибке] [1] https://i.stack.imgur.com/s22DL.png

1 Ответ

2 голосов
/ 16 апреля 2020

Полагаю, это то, что вы сделали по ошибке. Вы сохранили функцию

torch.save(model.state_dict, 'model_state.pth')

вместо state_dict ()

torch.save(model.state_dict(), 'model_state.pth')

В противном случае все должно работать как положено. (Я тестировал следующий код на Colab)

Замените model.state_dict() на model.state_dict, чтобы воспроизвести ошибку

import copy
model = TheModelClass()
torch.save(model.state_dict(), 'model_state.pth')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.load_state_dict(copy.deepcopy(torch.load("model_state.pth",device)))
...