pytorch single-gpu to multi-gpu model - PullRequest
       1

pytorch single-gpu to multi-gpu model

0 голосов
/ 24 апреля 2020

У меня есть модель Pytorch, которая работает на одном устройстве GPU. Согласно документации Pytorch, для запуска на нескольких GPU первым шагом является наследование от класса nn.DataParallel, который определен здесь .

import torch.nn as nn


class DataParallelModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.block1 = nn.Linear(10, 20)

        # wrap block2 in DataParallel
        self.block2 = nn.Linear(20, 20)
        self.block2 = nn.DataParallel(self.block2)

        self.block3 = nn.Linear(20, 20)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x

и DistributedDataParallel

Основная идея c заключается в передаче промежуточных данных между устройствами в проходе forward, определив блоки на разных устройства, подобные приведенному в этом примере, где вы превращаете модуль с одним графическим процессором в 2-GPU:

from torchvision.models.resnet import ResNet, Bottleneck

num_classes = 1000


class ModelParallelResNet50(ResNet):
    def __init__(self, *args, **kwargs):
        super(ModelParallelResNet50, self).__init__(
            Bottleneck, [3, 4, 6, 3], num_classes=num_classes, *args, **kwargs)

        self.seq1 = nn.Sequential(
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,

            self.layer1,
            self.layer2
        ).to('cuda:0')

        self.seq2 = nn.Sequential(
            self.layer3,
            self.layer4,
            self.avgpool,
        ).to('cuda:1')

        self.fc.to('cuda:1')

    def forward(self, x):
        x = self.seq2(self.seq1(x).to('cuda:1'))
        return self.fc(x.view(x.size(0), -1))

в промежуточном промежуточном результате результаты от self.seq1, обученного на cuda:0, передаются на cuda:1 перед передачей на seq2 и т. д. c.

Теперь мне не ясно, как сделать этот модуль обобщенным c для случая с несколькими GPU, как в случае с 4 экземплярами GPU.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...