Ну, может быть, его не должно быть в стандартной коллекции модулей, просто потому, что его можно определить очень просто:
class ParallelModule(nn.Sequential):
def __init__(self, *args):
super(ParallelModule, self).__init__( *args )
def forward(self, input):
output = []
for module in self:
output.append( module(input) )
return torch.cat( output, dim=1 )
Наследование "Parallel" из "Sequential" идеологически плохо, но работает хорошо. Теперь можно определить сети, как показано на рисунке, с помощью следующего кода:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.net = nn.Sequential(
nn.Conv2d( 1, 32, 3, padding=1 ), nn.ReLU(),
nn.Conv2d( 32, 64, 3, padding=1 ), nn.ReLU(),
nn.MaxPool2d( 3, stride=2 ), nn.Dropout2d( 0.25 ),
ParallelModule(
nn.Conv2d( 64, 64, 1 ),
nn.Sequential(
nn.Conv2d( 64, 64, 1 ), nn.ReLU(),
ParallelModule(
nn.Conv2d( 64, 32, (3,1), padding=(1,0) ),
nn.Conv2d( 64, 32, (1,3), padding=(0,1) ),
),
),
nn.Sequential(
nn.Conv2d( 64, 64, 1 ), nn.ReLU(),
nn.Conv2d( 64, 64, 3, padding=1 ), nn.ReLU(),
ParallelModule(
nn.Conv2d( 64, 32, (3,1), padding=(1,0) ),
nn.Conv2d( 64, 32, (1,3), padding=(0,1) ),
),
),
nn.Sequential(
#PrinterModule(),
nn.AvgPool2d( 3, stride=1, padding=1 ),
nn.Conv2d( 64, 64, 1 ),
),
),
nn.ReLU(),
nn.Conv2d( 256, 64, 1 ), nn.ReLU(),
nn.Conv2d( 64, 128, 3, padding=1 ), nn.ReLU(),
nn.MaxPool2d( 3, stride=2 ), nn.Dropout2d( 0.5 ),
nn.Flatten(),
nn.Linear( 4608, 128 ), nn.ReLU(),
nn.Linear( 128, 10 ), nn.LogSoftmax( dim=1 ),
)
def forward(self, x):
return self.net.forward( x )