У меня нет опыта работы с DataParallel
, но я думаю, что это может быть потому, что ваш тензор не является частью параметров модели. Вы можете сделать это: .
Я могу себе представить, что DataParallel
использует параметры модели для перемещения их к соответствующему графическому процессору.
См. этот ответ для получения дополнительной информации о разнице между тензором резака. и torch.nn.Parameter
.
Если вы не хотите, чтобы значения тензора обновлялись путем обратного распространения во время обучения, вы можете добавить requires_grad=False
.
Другой способ, который может сработать, - переопределить метод to
и инициализировать тензор в прямом проходе:
class custom_model(torch.nn.Module):
def __init__(self):
super(custom_model, self).__init__()
self.layer = torch.nn.Linear(10,5)
def forward(self, x):
return self.layer(x) @ torch.ones(5,1, device=self.device)
def to(self, device: str):
new_self = super(custom_model, self).to(device)
new_self.device = device
return new_self
или что-то вроде этого:
class custom_model(torch.nn.Module):
def __init__(self, device:str):
super(custom_model, self).__init__()
self.layer = torch.nn.Linear(10,5)
self.weight = torch.ones(5,1, device=device)
def forward(self, x):
return self.layer(x) @ self.weight
def to(self, device: str):
new_self = super(custom_model, self).to(device)
new_self.device = device
new_self.weight = torch.ones(5,1, device=device)
return new_self