Сети PyTorch в модели - PullRequest
       3

Сети PyTorch в модели

0 голосов
/ 16 июня 2020

Я хотел бы определить сеть, состоящую из множества шаблонов. Ниже под Определения сети приведен упрощенный пример, в котором первое определение сети используется в качестве шаблона во втором. Это не работает - когда я инициализирую мой оптимизатор, он говорит, что параметры сети пусты! Как мне это сделать правильно? Сеть, которая мне нужна, очень сложна.

Основная функция

if __name__ == "__main__":

myNet       = Network().cuda().train()
optimizer   = optim.SGD(myNet.parameters(), lr=0.01, momentum=0.9)

Определения сети:

class NetworkTemplate(nn.Module):


def __init__(self):
    super(NetworkTemplate, self).__init__()
    self.conv1 = nn.Conv2d(1, 3, kernel_size=1, bias=False)
    self.bn1 = nn.BatchNorm2d(3)

def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)

    return x

class Network(nn.Module):


def __init__(self, nNets):
    super(Network, self).__init__()

    self.nets = []
    for curNet in range(nNets):
        self.nets.append(NetworkTemplate())

    def forward(self, x):

        for curNet in self.nets:
            x = curNet(x)

        return x

1 Ответ

2 голосов
/ 16 июня 2020

Просто используйте torch.nn.Sequential? Например, self.nets=torch.nn.Sequential(*self.nets) после того, как вы заполнили self.nets, а затем вызывали return self.nets(x) в своей функции forward?

Если вы хотите сделать что-то более сложное, вы можете поместить все сети в torch.nn.ModuleList, но вы В этом случае вам нужно будет вручную вызвать их в вашем методе forward (но это может быть сложнее, чем просто последовательный).

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...