Ошибка загрузки неожиданного ключа модели "Pytourch 3.0" module.features.0.weight "в state_dict - PullRequest
0 голосов
/ 23 октября 2018

Я пытаюсь загрузить модель, которую я обучил, используя Pytorch, но продолжаю получать следующую ошибку:

Файл "convert.py", строка 12, в model.load_state_dict (torch.load ('model / model_vgg2d_2.pth')) Файл "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", строка 490, в load_state_dict .format (name))KeyError: 'неожиданный ключ' module.features.0.weight 'в state_dict'

Ниже приведен мой код:

import torch.onnx
import torch.nn as nn

class TempModel(nn.Module):
    def __init__(self):
        super(TempModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 5, (3, 3))
    def forward(self, inp):
        return self.conv1(inp)

model = nn.DataParallel(TempModel())
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
dummy_input = Variable(torch.randn(1, 3, 224, 224))
torch.onnx.export(model, dummy_input, "model_onnx/model_vgg2d_0.onnx")

Я работаю на той же машине, что и у меняиспользуется для обучения модели (которая имеет несколько графических процессоров).Есть идеи, что я делаю не так?

1 Ответ

0 голосов
/ 23 октября 2018

При загрузке state_dict вам нужно, чтобы он был state_dict той же той же модели : вы не можете загрузить state_dict модели VGG в совершенно другую BasicModel.


старый ответ
Вы сохранили модель без nn.DataParallel, примененной к модели, и теперь вы пытаетесь загрузить ее после добавления.Попробуйте

model = TempModel()
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
model = nn.DataParallel(model)  # parallel AFTER load
...