У меня проблема с загрузкой моей модели в проекте Image Classifier. Сначала я сохранил его:
model.class_to_idx = train_data.class_to_idx
checkpoint = {'arch': 'vgg19',
'learn_rate': learn_rate,
'epochs': epochs,
'state_dict': model.state_dict(),
'class_to_idx': model.class_to_idx,
'optimizer': optimizer.state_dict(),
'input_size': 25088,
'output_size': 102,
'momentum': momentum,
'batch_size':64,
'classifier' : classifier
}
torch.save(checkpoint, 'checkpoint.pth')
Затем я попытался загрузить сохраненный проект:
def load_checkpoint(filepath):
checkpoint = torch.load(filepath)
learn_rate = checkpoint['learn_rate']
optimizer.load_state_dict(checkpoint['optimizer'])
model = models.vgg16(pretrained=True)
model.epochs = checkpoint['epochs']
model.load_state_dict(checkpoint['state_dict'])
model.class_to_idx = checkpoint['class_to_idx']
model.classifier = checkpoint['classifier']
return learn_rate, optimizer, model
learn_rate, optimizer, model = load_checkpoint('checkpoint.pth')
И при попытке загрузить я получаю сообщение об ошибке:
<ipython-input-75-5bd1aa042c7f> in load_checkpoint(filepath)
9 model = models.vgg16(pretrained=True)
10 model.epochs = checkpoint['epochs']
---> 11 model.load_state_dict(checkpoint['state_dict'])
12 model.class_to_idx = checkpoint['class_to_idx']
13 model.classifier = checkpoint['classifier']
RuntimeError: Error(s) in loading state_dict for VGG:
Missing key(s) in state_dict: "classifier.0.weight", "classifier.0.bias", "classifier.3.weight", "classifier.3.bias", "classifier.6.weight", "classifier.6.bias".
Unexpected key(s) in state_dict: "classifier.fc1.weight", "classifier.fc1.bias", "classifier.fc2.weight", "classifier.fc2.bias".
Кажется, это проблема классификатора. Кто-нибудь знает, что происходит?