Как загрузить файл контрольной точки в модели Pytorch? - PullRequest
0 голосов
/ 13 февраля 2019

В моей модели pytorch я инициализирую свою модель и оптимизатор следующим образом.

model = MyModelClass(config, shape, x_tr_mean, x_tr,std)
optimizer = optim.SGD(model.parameters(), lr=config.learning_rate)

А вот и путь к моему файлу контрольных точек.

checkpoint_file = os.path.join (config.save_dir, "checkpoint.pth")

Чтобы загрузить этот файл контрольных точек, я проверяю и вижу, существует ли файл контрольных точек, а затем загружаю егоа также модель и оптимизатор.

if os.path.exists(checkpoint_file):
    if config.resume:
        torch.load(checkpoint_file)
        model.load_state_dict(torch.load(checkpoint_file))
        optimizer.load_state_dict(torch.load(checkpoint_file))

Также вот как я сохраняю свою модель и оптимизатор.

 torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'iter_idx': iter_idx, 'best_va_acc': best_va_acc}, checkpoint_file)

По какой-то причине я продолжаю получать странную ошибку при каждом запуске этого кода.

model.load_state_dict(torch.load(checkpoint_file))
File "/home/Josh/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for MyModelClass:
        Missing key(s) in state_dict: "mean", "std", "attribute.weight", "attribute.bias".
        Unexpected key(s) in state_dict: "model", "optimizer", "iter_idx", "best_va_acc"

Кто-нибудь знает, почему я получаю эту ошибку?

1 Ответ

0 голосов
/ 13 февраля 2019

Вы сохранили параметры модели в словаре.Вы должны использовать ключи, которые вы использовали при сохранении ранее, чтобы загрузить контрольную точку модели и state_dict s, как это:

if os.path.exists(checkpoint_file):
    if config.resume:
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

Вы можете проверить официальный учебник на сайте PyTorch для получения дополнительной информации.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...