Параллельный аналог torch.nn. Последовательный контейнер - PullRequest
0 голосов
/ 08 марта 2020

Просто интересно, почему я не могу найти subj в torch.nn? nn.Sequential довольно удобен, он позволяет определять сети в одном месте, четкие и наглядные, но ограниченные очень простыми! С помощью параллельного аналога (и небольшой помощи «идентичных» узлов для остаточных соединений) он образует законченный метод для построения любого комбинаторного пути net с прямой связью. Я что-то пропустил?

1 Ответ

0 голосов
/ 17 марта 2020

Ну, может быть, его не должно быть в стандартной коллекции модулей, просто потому, что его можно определить очень просто:

class ParallelModule(nn.Sequential):
    def __init__(self, *args):
        super(ParallelModule, self).__init__( *args )

    def forward(self, input):
        output = []
        for module in self:
            output.append( module(input) )
        return torch.cat( output, dim=1 )

Наследование "Parallel" из "Sequential" идеологически плохо, но работает хорошо. Теперь можно определить сети, как показано на рисунке, с помощью следующего кода: Network image by torchviz:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(  1, 32, 3, padding=1 ),        nn.ReLU(),
            nn.Conv2d( 32, 64, 3, padding=1 ),        nn.ReLU(),
            nn.MaxPool2d( 3, stride=2 ),              nn.Dropout2d( 0.25 ),

            ParallelModule(
                nn.Conv2d(  64, 64, 1 ),
                nn.Sequential( 
                    nn.Conv2d(  64, 64, 1 ),          nn.ReLU(),
                    ParallelModule(
                        nn.Conv2d(  64, 32, (3,1), padding=(1,0) ),
                        nn.Conv2d(  64, 32, (1,3), padding=(0,1) ),
                    ),
                ),
                nn.Sequential( 
                    nn.Conv2d(  64, 64, 1 ),          nn.ReLU(),
                    nn.Conv2d(  64, 64, 3, padding=1 ), nn.ReLU(),
                    ParallelModule(
                        nn.Conv2d(  64, 32, (3,1), padding=(1,0) ),
                        nn.Conv2d(  64, 32, (1,3), padding=(0,1) ),
                    ),
                ),
                nn.Sequential( 
                    #PrinterModule(),
                    nn.AvgPool2d( 3, stride=1, padding=1 ),
                    nn.Conv2d(  64, 64, 1 ),
                ),
            ),
            nn.ReLU(),
            nn.Conv2d( 256, 64, 1 ),                  nn.ReLU(),

            nn.Conv2d( 64, 128, 3, padding=1 ),       nn.ReLU(),
            nn.MaxPool2d( 3, stride=2 ),              nn.Dropout2d( 0.5 ),
            nn.Flatten(),
            nn.Linear( 4608, 128 ),                   nn.ReLU(),
            nn.Linear(  128,  10 ),                   nn.LogSoftmax( dim=1 ),
        )

    def forward(self, x):
        return self.net.forward( x )
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...