Как перенести вес собственной модели в ту же сеть, но с разным количеством классов в последнем слое? - PullRequest
0 голосов
/ 23 декабря 2018

У меня есть своя сеть в Pytorch.Сначала он обучен бинарному классификатору (2 класса).После 10 тысяч эпох я получил тренировочный вес как 10000_model.pth.Теперь я хочу использовать модель для задачи классификатора 4 классов, используя ту же сеть.Таким образом, я хочу перевести все обученные веса в двоичном классификаторе в задачу с 4 классами, без слоя, который будет случайной инициализацией.Как я мог это сделать?Это моя модель

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.conv_classify= nn.Conv2d(50, 2, 1, 1, bias=True) # number of class

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_classify(x))
        return x

Это то, что я сделал

model = Net ()
checkpoint_dict = torch.load('10000_model.pth')        
pretrained_dict = checkpoint_dict['state_dict']
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

На данный момент я должен вручную удалить pretrained_dict по имени.

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
pretrained_dict.pop('conv_classify.weight', None)
pretrained_dict.pop('conv_classify.bias', None)

Это означает, что pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} ничего не делает.

Что не так?Я использую Pytorch 1.0.Спасибо

1 Ответ

0 голосов
/ 23 декабря 2018

Обе сети имеют одинаковые слои и, следовательно, одинаковые ключи в state_dict, поэтому действительно

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

ничего не делает.Разница между ними заключается в весовых тензорах (их форма), а не в их именах.Другими словами, вы можете различить их по [v.shape for v in model.state_dict().values()], но не model.state_dict().keys().Ваш «обходной» подход верен.Если вы хотите сделать это немного менее ручным, я бы использовал

merged_dict = {}
for key in model_dict.keys():
    if 'conv_classify' in key: # or perhaps a more complex criterion
        merged_dict[key] = model_dict[key]
    else:
        merged_dict[key] = pretrained_dict[key]
...