У меня есть K
Автоэнкодеры, обернутые ModuleList
внутри настроенного PyTorch nn.Module
. Когда я пытаюсь использовать DataParallel
для распараллеливания моей модели между несколькими графическими процессорами, тело for-loop
, которое я использую для обучения каждого автоэнкодера в forward
, выполняется несколько раз, что вызывает серьезную проблему, потому что я добавляю в список внутри цикл, следовательно, имеет больше элементов, чем я ожидаю. Фрагмент заглушки может быть:
class CustomModel(nn.Module):
def __init__(self,**kwargs):
super(CustomModel, self).__init__()
self.autoencoders = nn.ModuleList([AutoEncoder(**kwargs) for i in range(K)])
self.encoders_output = []
self.decoders_output = []
def forward(self, input, input_length):
for i, autoencoder in enumerate(self.autoencoders):
decoder_output, encoder_output = autoencoder(input, input_length)
# debug(i)
self.encoders_output.append(encoder_output)
self.decoders_output.append(decoder_output)
self.encoders_output = []
self.decoders_output = []
return output
model = CustomModel(**kwargs)
model = nn.DataParallel(model).cuda()
Если вы попытаетесь раскомментировать debug(i)
, вы увидите, что каждый номер отлаживается (то есть печатается) много раз.