DataParallel делает неожиданное поведение с циклами for - PullRequest
0 голосов
/ 02 июня 2019

У меня есть K Автоэнкодеры, обернутые ModuleList внутри настроенного PyTorch nn.Module. Когда я пытаюсь использовать DataParallel для распараллеливания моей модели между несколькими графическими процессорами, тело for-loop, которое я использую для обучения каждого автоэнкодера в forward, выполняется несколько раз, что вызывает серьезную проблему, потому что я добавляю в список внутри цикл, следовательно, имеет больше элементов, чем я ожидаю. Фрагмент заглушки может быть:

class CustomModel(nn.Module):

    def __init__(self,**kwargs):
        super(CustomModel, self).__init__()
        self.autoencoders = nn.ModuleList([AutoEncoder(**kwargs) for i in range(K)])

        self.encoders_output = []
        self.decoders_output = []

    def forward(self, input, input_length):
        for i, autoencoder in enumerate(self.autoencoders):
            decoder_output, encoder_output = autoencoder(input, input_length)
            # debug(i)
            self.encoders_output.append(encoder_output)
            self.decoders_output.append(decoder_output)
        self.encoders_output = []
        self.decoders_output = []
        return output

model = CustomModel(**kwargs)
model = nn.DataParallel(model).cuda()

Если вы попытаетесь раскомментировать debug(i), вы увидите, что каждый номер отлаживается (то есть печатается) много раз.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...