Я скачал предварительно обученную модель в формате pth.tar, и после того, как я позвонил ей torch.load()
, я вижу, что это формат orderdict
с соответствующими названиями слоев и их весами.Затем я попытался просто зациклить ключи и значения с помощью словаря, чтобы создать nn.ParameterDict()
, но из-за соглашений об именах в ключах для orderdict
, например, layer.1.weights
, я получил бы эту ошибку
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-91-578f5b12a2c6> in <module>()
----> 1 nn.ParameterDict(checkpoint)
D:\Anaconda\lib\site-packages\torch\nn\modules\container.py in __init__(self, parameters)
425 super(ParameterDict, self).__init__()
426 if parameters is not None:
--> 427 self.update(parameters)
428
429 def __getitem__(self, key):
D:\Anaconda\lib\site-packages\torch\nn\modules\container.py in update(self, parameters)
492 if isinstance(parameters, OrderedDict):
493 for key, parameter in parameters.items():
--> 494 self[key] = parameter
495 else:
496 for key, parameter in sorted(parameters.items()):
D:\Anaconda\lib\site-packages\torch\nn\modules\container.py in __setitem__(self, key, parameter)
431
432 def __setitem__(self, key, parameter):
--> 433 self.register_parameter(key, parameter)
434
435 def __delitem__(self, key):
D:\Anaconda\lib\site-packages\torch\nn\modules\module.py in register_parameter(self, name, param)
136 "Got {}".format(torch.typename(name)))
137 elif '.' in name:
--> 138 raise KeyError("parameter name can't contain \".\"")
139 elif name == '':
140 raise KeyError("parameter name can't be empty string \"\"")
KeyError: 'parameter name can\'t contain "."'
Итак, в конце этого запуска я должен переименовать слои с layer.1.weights
на что-то вроде layer_1_weights
во время преобразования nn.ParameterDict
?Это имеет значение?Я также погуглил и посмотрел на load_state_dict
, и, исходя из моего базового понимания, вам нужно предварительно определить класс модели, а затем загрузить в него веса, ну, в данном случае у меня нет класса модели, япытаясь построить класс модели с информацией из этого orderdict
файла.Так каков правильный подход к этому?