Pytorch 1.0: что делает net.to (устройство) в nn.DataParallel? - PullRequest
0 голосов
/ 24 апреля 2019

Следующий код из учебника для паралича данных Pytorch выглядит странно для меня:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  model = nn.DataParallel(model)

model.to(device)

Насколько я знаю, mode.to(device) копирует данные в графический процессор.

DataParallel автоматически разбивает ваши данные и отправляет заказы на работу нескольким моделям на нескольких графических процессорах.После того, как каждая модель завершит свою работу, DataParallel собирает и объединяет результаты, прежде чем вернуть их вам.

Если DataParallel выполняет копирование, что делает to(device) здесь?

1 Ответ

0 голосов
/ 25 апреля 2019

Они добавляют несколько строк в учебник , чтобы объяснить nn.DataParallel.

DataParallel автоматически разбивает ваши данные и отправляет заказы на выполнение работ нескольким моделям на разных графических процессорах, используя данные. После того как каждая модель завершит свою работу, DataParallel собирает и объединяет результаты для вас.

Из приведенной выше цитаты можно понять, что nn.DataParallel - это просто класс-обертка, информирующий model.cuda() о необходимости сделать несколько копий для графических процессоров.

В моем случае на моем ноутбуке нет графического процессора. Я до сих пор без проблем звоню nn.DataParallel().

import torch
import torchvision

model = torchvision.models.alexnet()
model = torch.nn.DataParallel(model)
# No error appears if I don't move the model to `cuda`
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...