Некоторые проблемы при загрузке весов в Pytorch - PullRequest
0 голосов
/ 25 октября 2019

Я искал много ресурсов для решения этой проблемы, но все еще застрял здесь.

Я следовал руководству по pytorch и сохранил параметры с помощью

torch.save(the_model.state_dict(), PATH)

, затем загрузил параметры с помощью

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

После загрузки параметров я напечатал свою модель. Вышла ошибка.

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

Я обнаружил, что некоторые люди сталкиваются с той же проблемой, но кажется, что ее можно игнорировать ?! Затем я пытаюсь ввести данные в модель с помощью `` model.forward () '' ', появилась другая ошибка.

AttributeError: 'IncompatibleKeys' object has no attribute 'forward'

Я знаю, что этот тип метода сохранения (the_model.state_dict()) простосохранить "веса". Его следует использовать только .eval() из-за сохранения важной информации (выпадение, групповой вызов и т. Д.). Поэтому я пытаюсь model.eval(), он все еще имеет ту же ошибку.

AttributeError: 'IncompatibleKeys' object has no attribute 'eval'

Вот некоторый относительный код:

  • Инициализация модели:

    model = VAE(some constructor parameters)

  • После обучения:

    checkpoint_path = os.path.join(save_path, "E%02d.pkl" % ep)
    torch.save(model.state_dict(), checkpoint_path)
    
  • Инициализация той же модели и загрузка параметров в модель:

    model = VAE(some constructor parameters)
    checkpoint = torch.load("E24.pkl", map_location='cuda:0')
    model = model.load_state_dict(checkpoint)
    

Я больше не буду тренировать эту модель. Я просто хочу загрузить параметры, а затем проверить производительность. Спасибо за чтение. :)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...