Есть ли способ скопировать все параметры одной модели Pytorch в другую, особенно средние значения нормализации партии и стандартного ввода? - PullRequest
0 голосов
/ 01 декабря 2018

В Интернете я нашел много правильных способов скопировать один параметр модели Pytorch в другой, но каким-то образом операция копирования-вставки всегда пропускает параметры нормализации пакета.Все работает нормально, пока в моей модели используются только такие модули, как conv2d, linear, dropout, max pool и т. Д.Но как только я добавляю нормализацию партии в модели pytorch, нижеприведенный скрипт перестает работать и точность во время теста отличается:

net = model()
copy_net = model()

for param in net.module.parameters():
    copy_param.append(param.clone().detach())

count = 0
for param in copy_net.module.parameters():
    param.data =  copy_param[count]
    param.requires_grad = False
    count = count +1

Кто-нибудь может дать мне возможное решение для копирования нормализации партии также?

1 Ответ

0 голосов
/ 04 апреля 2019

net.load_state_dict(copy_net.state_dict()) должно работать.

Согласно @dxtx, в философии pytorch, dict состояния должен охватывать все состояния в «модуле», например, в модуле пакетной нормы, скользящем среднем и переменной,если я правильно запомнил, должна быть часть государственного диктата.Но на самом деле, если вы написали модуль, подобный пакетной норме, вам придется переопределить метод 'state_dict'.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...