Ошибка времени выполнения при загрузке state_dict для модели VGG16 в Pytorch - PullRequest
0 голосов
/ 15 апреля 2020

Я использовал этот следующий код для получения веса и смещения vgg16:

class vgg_features(nn.Module):
def __init__(self):
    super(vgg_features, self).__init__()
    # get vgg16 features up to conv 4_3
    self.model = nn.Sequential(*list(vgg16(pretrained=True).features)[:23])
    # print(self.model)
    # will not need to compute gradients
    for param in self.parameters():
        param.requires_grad=False

def forward(self, x, renormalize=True):
    # change normaliztion form [-1,1] to VGG normalization
    if renormalize:
        x = ((x*.5+.5)-torch.cuda.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))/torch.cuda.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1)
    return self.model(x)

Однако я получаю эту ошибку:

RuntimeError: Ошибка (и) при загрузке state_dict для модели : Отсутствуют ключи в state_dict: "vgg.model.0.weight", "vgg.model.0.bias", "vgg.model.2.weight", "vgg.model.2.bias", " vgg.model.5.weight "," vgg.model.5.bias "," vgg.model.7.weight "," vgg.model.7.bias "," vgg.model.10.weight "," vgg.model.10.bias "," vgg.model.12.weight "," vgg.model.12.bias "," vgg.model.14.weight "," vgg.model.14.bias "," vgg.model.17.weight "," vgg.model.17.bias "," vgg.model.19.weight "," vgg.model.19.bias "," vgg.model.21.weight "," vgg.model.21.bias». Неожиданные ключи в state_dict: "vgg.features.0.weight", "vgg.features.0.bias", "vgg.features.2.weight", "vgg.features.2.bias", "vgg .features.5.weight "," vgg.features.5.bias "," vgg.features.7.weight "," vgg.features.7.bias "," vgg.features.10.weight "," vgg .features.10.bias "," vgg.features.12.weight "," vgg.features.12.bias "," vgg.features.14.weight "," vgg.features.14.bias "," vgg .features.17.weight "," vgg.features.17.bias "," vgg.features.19.weight "," vgg.features.19.bias "," vgg.features.21.weight "," vgg .features.21.bias».

строка, в которой я получаю сообщение об ошибке:

model.load_state_dict(torch.load(os.path.join(args.outroot, 
'%s_net.pth'%args.load)), strict=args.test)

Пожалуйста, помогите мне исправить это.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...