В pytorch, как вы можете взять предварительно обученные веса в формате orderdict и превратить его обратно в модель? - PullRequest
0 голосов
/ 29 ноября 2018

Я скачал предварительно обученную модель в формате pth.tar, и после того, как я позвонил ей torch.load(), я вижу, что это формат orderdict с соответствующими названиями слоев и их весами.Затем я попытался просто зациклить ключи и значения с помощью словаря, чтобы создать nn.ParameterDict(), но из-за соглашений об именах в ключах для orderdict, например, layer.1.weights, я получил бы эту ошибку

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-91-578f5b12a2c6> in <module>()
----> 1 nn.ParameterDict(checkpoint)

D:\Anaconda\lib\site-packages\torch\nn\modules\container.py in __init__(self, parameters)
    425         super(ParameterDict, self).__init__()
    426         if parameters is not None:
--> 427             self.update(parameters)
    428 
    429     def __getitem__(self, key):

D:\Anaconda\lib\site-packages\torch\nn\modules\container.py in update(self, parameters)
    492             if isinstance(parameters, OrderedDict):
    493                 for key, parameter in parameters.items():
--> 494                     self[key] = parameter
    495             else:
    496                 for key, parameter in sorted(parameters.items()):

D:\Anaconda\lib\site-packages\torch\nn\modules\container.py in __setitem__(self, key, parameter)
    431 
    432     def __setitem__(self, key, parameter):
--> 433         self.register_parameter(key, parameter)
    434 
    435     def __delitem__(self, key):

D:\Anaconda\lib\site-packages\torch\nn\modules\module.py in register_parameter(self, name, param)
    136                             "Got {}".format(torch.typename(name)))
    137         elif '.' in name:
--> 138             raise KeyError("parameter name can't contain \".\"")
    139         elif name == '':
    140             raise KeyError("parameter name can't be empty string \"\"")

KeyError: 'parameter name can\'t contain "."'

Итак, в конце этого запуска я должен переименовать слои с layer.1.weights на что-то вроде layer_1_weights во время преобразования nn.ParameterDict?Это имеет значение?Я также погуглил и посмотрел на load_state_dict, и, исходя из моего базового понимания, вам нужно предварительно определить класс модели, а затем загрузить в него веса, ну, в данном случае у меня нет класса модели, япытаясь построить класс модели с информацией из этого orderdict файла.Так каков правильный подход к этому?

...