? Загрузить модель пироха с 0.4.1 до 0.4.0? - PullRequest
0 голосов
/ 08 декабря 2018

Я обучил модель DENSENET161, используя pytorch 0.4.1 (GPU), и в тестовой среде я должен загрузить ее в версии pytorch 0.4.0 (CPU).Я уже использую model.cpu(), но при загрузке статического словаря model.load_state_dict(checkpoint['state_dict'])

я получаю следующую ошибку:

RuntimeError: Ошибка (и) при загрузке state_dict для DenseNet:Неожиданный ключ (ы) в state_dict: "features.norm0.num_batches_tracked", "features.denseblock1.denselayer1.norm1.num_batches_tracked", "features.denseblock1.denselayer1.norm2.num_batches_tracked", "features.denseblock1.ches_track_1.jpg", ...

1 Ответ

0 голосов
/ 08 декабря 2018

Кажется, это связано с различием в реализации слоев нормализации между PyTorch 0.4.1 и 0.4 - первый отслеживает некоторую переменную состояния, называемую num_batches_tracked, которую Pytorch 0.4 не ожидает.Предполагая, что есть только неожиданные ключи и нет отсутствующих ключей (что я не могу сказать наверняка, так как вы вырезали сообщение об ошибке), вы можете просто удалить посторонние и, надеюсь, модель загрузится.Поэтому попробуйте

model_dict = checkpoint['state_dict']
filtered = {
    k: v for k, v in model_dict.items() if 'num_batches_tracked' not in k
}
model.load_state_dict(filtered)

. Обратите внимание, что во внутренних структурах нормализации могут быть изменения, отличные от того, что вы видите здесь, поэтому даже если это исправление подавляет исключение, модель все равно может вести себя некорректно.

...