Совместное использование параметров между определенными слоями разных экземпляров одной и той же модели Pytorch - PullRequest
0 голосов
/ 01 сентября 2018

У меня есть модель pytorch с несколькими слоями, которая выглядит примерно так

class CNN(nn.Module):
    def __init__(self):
        super(CNN).__init__()
        self.layer1 = nn.Conv2d(#parameters)
        self.layer2 = nn.Conv2d(#different_parameters)
        self.layer3 = nn.Conv2d(#other_parameters)
        self.layer4 = nn.Conv2d(#final_parameters)

    def forward(self, x):
        out1 = self.layer2(F.relu(self.layer1(x)))
        out2 = self.layer4(F.relu(self.layer3(x)))
        return torch.cat((out1, out2), 0)

Затем я хочу создать несколько экземпляров этого класса (cnn1, cnn2) и поделиться параметрами первого пути (layer1, layer2) между экземплярами, оставив другие параметры раздельными.

Есть ли оптимальный / поддерживаемый способ сделать это?

1 Ответ

0 голосов
/ 03 сентября 2018

Просто соберите layer1, layer2 как независимый модуль.

один пример: model1 и mode2 имеют частный полностью подключенный уровень, но совместно используют слой conv2d, то есть экстрактор функций

feature_ex = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(1, 6, 5)),
                                        ('relu1', nn.ReLU()),
                                        ('maxpool1', nn.MaxPool2d((2, 2))),
                                        ('conv2', nn.Conv2d(6, 16, 5)),
                                        ('relu2', nn.ReLU()),
                                        ('maxpool2', nn.MaxPool2d(2))
                                        ]))


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = feature_ex(x)   # [1]
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)     # [2]

        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:       # Get the products
            num_features *= s
        return num_features


model1 = Net()
model2 = Net()

img = torch.randn(10, 1, 32, 32)
out1 = model1.forward(img)
out2 = model2.forward(img)

# [1]
# print(np.allclose(out1.detach().numpy(), out2.detach().numpy()))
# output: True

# [2]
print(np.allclose(out1.detach().numpy(), out2.detach().numpy()))
# output: False
...