Либо слишком мало, либо слишком много аргументов для nn.Sequential - PullRequest
0 голосов
/ 05 апреля 2020

Я новичок в PyTorch, поэтому, пожалуйста, извините за мой глупый вопрос.

Я определяю nn.Sequential при инициализации моего объекта Encoder следующим образом:

self.list_of_blocks = [EncoderBlock(n_features, n_heads, n_hidden, dropout) for _ in range(n_blocks)]
self.blocks = nn.Sequential(*self.list_of_blocks)

Форвард EncoderBlock выглядит так:

def forward(self, x, mask):

В forward () моего кодировщика я пытаюсь сделать:

z0 = self.blocks(z0, mask)

Я ожидаю, что nn.Sequential передаст эти два аргумента отдельным блокам.

Однако я получаю

TypeError: forward() takes 2 positional arguments but 3 were given

Когда я пытаюсь:

z0 = self.blocks(z0)

Я получаю (понятно):

TypeError: forward() takes 2 positional arguments but only 1 was given

Когда я делаю не используйте nn.Sequential и просто выполняйте один EncoderBlock за другим, он работает:

for i in range(self.n_blocks):
     z0 = self.list_of_blocks[i](z0, mask)

Вопрос : что я делаю неправильно и как правильно использовать nn.Sequential в этом случае

1 Ответ

1 голос
/ 06 апреля 2020

Последовательный в общем случае не работает с несколькими входами и выходами.

Это часто обсуждаемые темы c, см. Форум PyTorch и проблемы GitHub # 1908 или # 9979 .

Вы можете определить свою собственную версию последовательного. Предполагая, что маска одинакова для всех ваших блоков кодировщика (например, как в сетях Transformer), вы можете сделать:

class MaskedSequential(nn.Sequential):
    def forward(self, x, mask):
        for module in self._modules.values():
            x = module(x, mask)
        return inputs

Или, если ваши EncoderBlock s возвращают кортежи, вы можете использовать более общие решение, предложенное в одной из проблем GitHub :

class MySequential(nn.Sequential):
    def forward(self, *inputs):
        for module in self._modules.values():
            if type(inputs) == tuple:
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs
...