Добавление параметров в предварительно обученную модель - PullRequest
0 голосов
/ 13 мая 2019

В Pytorch мы загружаем предварительно обученную модель следующим образом:

net.load_state_dict(torch.load(path)['model_state_dict'])

Тогда структура сети и загруженная модель должны быть точно такими же. Однако можно ли загрузить веса, но затем изменить сеть / добавить дополнительный параметр?

Примечание: Если перед загрузкой весов мы добавим дополнительный параметр в модель, например,

self.parameter = Parameter(torch.ones(5),requires_grad=True) 

мы получим Missing key(s) in state_dict: ошибку при загрузке весов.

1 Ответ

0 голосов
/ 14 мая 2019

Давайте создадим модель и сохраним ее состояние.

class Model1(nn.Module):
    def __init__(self):
        super(Model1, self).__init__()

        self.encoder = nn.LSTM(100, 50)

    def forward(self):
        pass


model1 = Model1()
torch.save(model1.state_dict(), 'filename.pt') # saving model

Затем создайте вторую модель, которая имеет несколько слоев, общих для первой модели. Загрузите состояния первой модели и загрузите ее в общие слои второй модели.

class Model2(nn.Module):
    def __init__(self):
        super(Model2, self).__init__()

        self.encoder = nn.LSTM(100, 50)
        self.linear = nn.Linear(50, 200)

    def forward(self):
        pass


model1_dict = torch.load('filename.pt')
model2 = Model2()
model2_dict = model2.state_dict()

# 1. filter out unnecessary keys
filtered_dict = {k: v for k, v in model1_dict.items() if k in model2_dict}
# 2. overwrite entries in the existing state dict
model2_dict.update(filtered_dict)
# 3. load the new state dict
model2.load_state_dict(model2_dict)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...