Я искал много ресурсов для решения этой проблемы, но все еще застрял здесь.
Я следовал руководству по pytorch и сохранил параметры с помощью
torch.save(the_model.state_dict(), PATH)
, затем загрузил параметры с помощью
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
После загрузки параметров я напечатал свою модель. Вышла ошибка.
IncompatibleKeys(missing_keys=[], unexpected_keys=[])
Я обнаружил, что некоторые люди сталкиваются с той же проблемой, но кажется, что ее можно игнорировать ?! Затем я пытаюсь ввести данные в модель с помощью `` model.forward () '' ', появилась другая ошибка.
AttributeError: 'IncompatibleKeys' object has no attribute 'forward'
Я знаю, что этот тип метода сохранения (the_model.state_dict()
) простосохранить "веса". Его следует использовать только .eval()
из-за сохранения важной информации (выпадение, групповой вызов и т. Д.). Поэтому я пытаюсь model.eval()
, он все еще имеет ту же ошибку.
AttributeError: 'IncompatibleKeys' object has no attribute 'eval'
Вот некоторый относительный код:
Инициализация модели:
model = VAE(some constructor parameters)
После обучения:
checkpoint_path = os.path.join(save_path, "E%02d.pkl" % ep)
torch.save(model.state_dict(), checkpoint_path)
Инициализация той же модели и загрузка параметров в модель:
model = VAE(some constructor parameters)
checkpoint = torch.load("E24.pkl", map_location='cuda:0')
model = model.load_state_dict(checkpoint)
Я больше не буду тренировать эту модель. Я просто хочу загрузить параметры, а затем проверить производительность. Спасибо за чтение. :)