Загрузка модели с pytorch - PullRequest
0 голосов
/ 09 февраля 2020

У меня проблема с загрузкой моей модели в проекте Image Classifier. Сначала я сохранил его:

model.class_to_idx = train_data.class_to_idx

checkpoint = {'arch': 'vgg19',
              'learn_rate': learn_rate,
              'epochs': epochs,
              'state_dict': model.state_dict(),
              'class_to_idx': model.class_to_idx,
              'optimizer': optimizer.state_dict(),
              'input_size': 25088,
              'output_size': 102,
              'momentum': momentum,
              'batch_size':64,
              'classifier' : classifier
             }

torch.save(checkpoint, 'checkpoint.pth')

Затем я попытался загрузить сохраненный проект:

def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)

    learn_rate = checkpoint['learn_rate']

    optimizer.load_state_dict(checkpoint['optimizer'])

    model = models.vgg16(pretrained=True)
    model.epochs = checkpoint['epochs']
    model.load_state_dict(checkpoint['state_dict'])
    model.class_to_idx = checkpoint['class_to_idx']
    model.classifier = checkpoint['classifier']

    return learn_rate, optimizer, model

learn_rate, optimizer, model = load_checkpoint('checkpoint.pth')

И при попытке загрузить я получаю сообщение об ошибке:

<ipython-input-75-5bd1aa042c7f> in load_checkpoint(filepath)
      9     model = models.vgg16(pretrained=True)
     10     model.epochs = checkpoint['epochs']
---> 11     model.load_state_dict(checkpoint['state_dict'])
     12     model.class_to_idx = checkpoint['class_to_idx']
     13     model.classifier = checkpoint['classifier']

RuntimeError: Error(s) in loading state_dict for VGG:
    Missing key(s) in state_dict: "classifier.0.weight", "classifier.0.bias", "classifier.3.weight", "classifier.3.bias", "classifier.6.weight", "classifier.6.bias". 
    Unexpected key(s) in state_dict: "classifier.fc1.weight", "classifier.fc1.bias", "classifier.fc2.weight", "classifier.fc2.bias". 

Кажется, это проблема классификатора. Кто-нибудь знает, что происходит?

1 Ответ

1 голос
/ 11 февраля 2020

Комментарии Jodag в основе проблемы. Если fc1 fc2 соответствуют classifier.0 classifier.3, classifier.6, вы можете настроить словарь, чтобы связать их. При загрузке весов в модель убедитесь, что вы добавили опцию strict = False.

Вам потребуется переобучить вашу модель для классификатора, потому что в вашем состоянии нет весов для 3 слоев, но есть 2 неиспользованных веса для слоев - но оно должно сходиться очень быстро (из личного опыта).

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...