Ошибка сохранения модели PyTorch: «Не удается обработать локальный объект» - PullRequest
1 голос
/ 29 мая 2020

Когда я пытаюсь сохранить модель PyTorch с помощью этого:

checkpoint = {'model': Net(), 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')

, я сталкиваюсь с этой проблемой:

    E:\PROGRAM FILES\Anaconda\envs\staj_projesi\lib\site-packages\torch\serialization.py:251: UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading.
...

      "type " + obj.__name__ + ". It won't be checked "
    Can't pickle local object 'trainModel.<locals>.Net'

Когда я пытаюсь сохранить модель PyTorch с помощью этого:

checkpoint = {'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')

У меня проблем нет, но я хочу сохранить класс ИНС. Как я могу решить эту проблему? Кроме того, я мог сохранить модель с первой структурой в других проектах до

1 Ответ

0 голосов
/ 29 мая 2020

Вы не можете! torch.save сохраняет только объекты state_dict().

Когда вы используете следующее:

checkpoint = {'model': Net(), 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')

Вы пытаетесь сохранить саму модель, но эти данные сохраняются в model.state_dict() и при загрузке модели с state_dict вы должны сначала инициировать объект модели.

Именно по этой причине второй метод работает правильно:

checkpoint = {'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')

Я бы посоветовал прочитать документы pytorch о том, как правильно сохранить \ загрузить модель по следующей ссылке: https://pytorch.org/tutorials/beginner/saving_loading_models.html

...