Ключи словаря PyTorch не совпадают - PullRequest
1 голос
/ 14 июля 2020

Я пытаюсь реализовать сверточный LSTM, который нашел в Интернете, и кажется, что ключи словаря не совпадают:

Предварительно обученные веса находятся в маринованном словаре со следующими ключами:

pkl_load = torch.load(trained_model_dir)
print(pkl_load.keys())

odict_keys(['module.E.conv1.weight', 'module.E.bn1.weight', 'module.E.bn1.bias', ....

Однако ключи в state_dict для фактической модели NN:

"E.conv1.weight", "E.bn1.weight", "E.bn1.bias", ....

Я получаю сообщение об ошибке при попытке загрузить предварительно обученные веса в state_dict, потому что ключи не не совпадают. Как можно обойти это? (Извините, если это легко, я новичок в PyTorch).

1 Ответ

0 голосов
/ 14 июля 2020

Вы можете сделать что-то вроде:

keys = ['module.E.conv1.weight', 'module.E.bn1.weight', 'module.E.bn1.bias']
res = []
for key in keys:
    words = key.split('.')
    tempRes = words[1:]
    newWord = '.'.join(tempRes)
    res.append(newWord)
print(res)

вывод:

['E.conv1.weight', 'E.bn1.weight', 'E.bn1.bias']
...