Pytorch DataParallel не работает, когда модель содержит тензорную операцию - PullRequest
1 голос
/ 22 марта 2020

Если моя модель содержит только nn.Module слоев, таких как nn.Linear, nn.DataParallel работает нормально.

x = torch.randn(100,10)

class normal_model(torch.nn.Module):
    def __init__(self):
        super(normal_model, self).__init__()
        self.layer = torch.nn.Linear(10,1)

    def forward(self, x):
        return self.layer(x)

model = normal_model()
model = nn.DataParallel(model.to('cuda:0'))
model(x)

Однако, когда моя модель содержит тензорную операцию, такую ​​как

class custom_model(torch.nn.Module):
    def __init__(self):
        super(custom_model, self).__init__()
        self.layer = torch.nn.Linear(10,5)
        self.weight = torch.ones(5,1, device='cuda:0')
    def forward(self, x):
        return self.layer(x) @ self.weight

model = custom_model()
model = torch.nn.DataParallel(model.to('cuda:0'))
model(x) 

Это дает мне следующую ошибку

RuntimeError: Обнаружено RuntimeError в реплике 1 на устройстве 1. Исходная трассировка (последний вызов был последним): файл "/ opt / conda / lib / python3 .6 / site-packages / torch / nn /rallel / parallel_apply.py ", строка 60, в файле _worker output = module (* input, ** kwargs)" /opt/conda/lib/python3.6 /site-packages/torch/nn/modules/module.py ", строка 541, в вызов result = self.forward (* input, ** kwargs) Файл" ", строка 7, в прямом возврате self.layer (x) @ self.weight RuntimeError: аргументы располагаются на разных графических процессорах в /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:277

Как избежать этой ошибки, когда у нас есть некоторые тензорные операции в нашей модели?

1 Ответ

1 голос
/ 22 марта 2020

У меня нет опыта работы с DataParallel, но я думаю, что это может быть потому, что ваш тензор не является частью параметров модели. Вы можете сделать это: .

Я могу себе представить, что DataParallel использует параметры модели для перемещения их к соответствующему графическому процессору.

См. этот ответ для получения дополнительной информации о разнице между тензором резака. и torch.nn.Parameter.

Если вы не хотите, чтобы значения тензора обновлялись путем обратного распространения во время обучения, вы можете добавить requires_grad=False.

Другой способ, который может сработать, - переопределить метод to и инициализировать тензор в прямом проходе:

class custom_model(torch.nn.Module):
    def __init__(self):
        super(custom_model, self).__init__()
        self.layer = torch.nn.Linear(10,5)
    def forward(self, x):
        return self.layer(x) @ torch.ones(5,1, device=self.device)
    def to(self, device: str):
        new_self = super(custom_model, self).to(device)
        new_self.device = device
        return new_self

или что-то вроде этого:

class custom_model(torch.nn.Module):
    def __init__(self, device:str):
        super(custom_model, self).__init__()
        self.layer = torch.nn.Linear(10,5)
        self.weight = torch.ones(5,1, device=device)
    def forward(self, x):
        return self.layer(x) @ self.weight
    def to(self, device: str):
        new_self = super(custom_model, self).to(device)
        new_self.device = device
        new_self.weight = torch.ones(5,1, device=device)
        return new_self
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...