Лучшая практика для передачи имени устройства PyTorch в модель - PullRequest
0 голосов
/ 29 апреля 2020

В настоящее время я разделил train.py на model.py для моего проекта глубокого обучения.

Таким образом, для наборов данных они отправляются на устройство cuda внутри epoch for loop, как показано ниже.

train.py

...
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
model = MyNet(~).to(device)
...
for batch_data in train_loader:
    s0 = batch_data[0].to(device)
    s1 = batch_data[1].to(device)
    pred = model(s0, s1)

Однако внутри моей модели (в model.py) также требуется доступ к переменной устройства для пропуска соединения, как метод. Чтобы создать новую копию скрытого блока (для остаточного соединения)

model.py

class MyNet(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super(MyNet, self).__init__()
        self.conv1 = GCNConv(in_feats, hid_feats)
        ...

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x1 = copy.copy(x.float())
        x = self.conv1(x, edge_index)
        skip_conn = torch.zeros(len(data.batch), x1.size(1)).to(device)  # <--
        (some opps for x1 -> skip_conn)
        x = torch.cat((x, skip_conn), 1)

В этом случае в настоящее время я передаю device в качестве параметра, однако, я считаю, это не лучшая практика.

  1. Где должна быть лучшая практика для отправки набора данных в CUDA?
  2. В случае нескольких сценариев необходим доступ к device, как это можно сделать? Я справлюсь с этим? (параметр, глобальная переменная?)

1 Ответ

1 голос
/ 29 апреля 2020

Вы можете добавить новый атрибут к MyModel, чтобы сохранить информацию device и использовать ее при инициализации skip_conn.

class MyNet(nn.Module):
def __init__(self, in_feats, hid_feats, out_feats, device): # <--
    super(MyNet, self).__init__()
    self.conv1 = GCNConv(in_feats, hid_feats)
    self.device = device # <--
    self.to(self.device) # <--
    ...

def forward(self, data):
    x, edge_index = data.x, data.edge_index
    x1 = copy.copy(x.float())
    x = self.conv1(x, edge_index)
    skip_conn = torch.zeros(len(data.batch), x1.size(1), device=self.device)  # <--
    (some opps for x1 -> skip_conn)
    x = torch.cat((x, skip_conn), 1)

Обратите внимание, что в этом примере MyNet отвечает для всех логи устройства c, включая вызов .to(device). Таким образом, мы инкапсулируем все связанные с моделью устройства управления в самом классе модели.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...