Pytorch-lightning - не работает загрузка модели с КПП - PullRequest
0 голосов
/ 04 августа 2020

Я обучил vanilla vae, который я модифицировал из этого репозитория . Когда я пытаюсь использовать обученную модель, я не могу загрузить веса, используя load_from_checkpoint. Кажется, есть несоответствие между моим объектом контрольной точки и моим объектом lightningModule.

Я установил эксперимент (VAEXperiment), используя pytorch-lightning LightningModule. Пытаюсь загрузить веса в сеть с помощью:

#building a new model
model = VanillaVAE(**config['model_params'])
model.build_layers()

#loading the weights
experiment = VAEXperiment(model, config['exp_params'])
experiment.load_from_checkpoint(path_to_checkpoint, config['exp_params'])

Еще пробовал:

checkpoint = torch.load(path_to_checkpoint, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])

Но получаю ошибку Unexpected key(s) in state_dict: "model.encoder.0.0.weight", "model.encoder.0.0.bias" ...

Я также следил за этой проблемой на https://github.com/PyTorchLightning/pytorch-lightning/issues/924 https://github.com/PyTorchLightning/pytorch-lightning/issues/2798

Почему я получаю эту ошибку? Это из-за модулей кодировщика и декодера в моей модели? Судя по журналу регистрации проблем git, кажется, что ошибка устранена. Что я делаю не так?

1 Ответ

0 голосов
/ 04 августа 2020

Размещение ответа из комментариев:

experiment.load_state_dict(checkpoint['state_dict'])
...