Я обучил vanilla vae, который я модифицировал из этого репозитория . Когда я пытаюсь использовать обученную модель, я не могу загрузить веса, используя load_from_checkpoint
. Кажется, есть несоответствие между моим объектом контрольной точки и моим объектом lightningModule
.
Я установил эксперимент (VAEXperiment
), используя pytorch-lightning LightningModule
. Пытаюсь загрузить веса в сеть с помощью:
#building a new model
model = VanillaVAE(**config['model_params'])
model.build_layers()
#loading the weights
experiment = VAEXperiment(model, config['exp_params'])
experiment.load_from_checkpoint(path_to_checkpoint, config['exp_params'])
Еще пробовал:
checkpoint = torch.load(path_to_checkpoint, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])
Но получаю ошибку Unexpected key(s) in state_dict: "model.encoder.0.0.weight", "model.encoder.0.0.bias"
...
Я также следил за этой проблемой на https://github.com/PyTorchLightning/pytorch-lightning/issues/924 https://github.com/PyTorchLightning/pytorch-lightning/issues/2798
Почему я получаю эту ошибку? Это из-за модулей кодировщика и декодера в моей модели? Судя по журналу регистрации проблем git, кажется, что ошибка устранена. Что я делаю не так?