net.load_state_dict (torch.load ('rnn_x_epoch.net')) не работает на процессоре - PullRequest
0 голосов
/ 30 января 2019

Я использую Pytorch для обучения нейронной сети.Когда я тренируюсь и тестирую на GPU, все работает нормально.Но когда я пытаюсь загрузить параметры модели в ЦП, используя:

net.load_state_dict(torch.load('rnn_x_epoch.net'))

, я получаю следующую ошибку:

RuntimeError: cuda runtime error (35) : CUDA driver version is insufficient for CUDA runtime version at torch/csrc/cuda/Module.cpp:51

Я искал ошибку, в основном из-за драйвера CUDAзависимость, но так как я работаю на процессоре, когда я получаю эту ошибку, это должно быть что-то еще, или, может быть, я что-то пропустил.Поскольку он работает нормально с использованием графического процессора, я мог бы просто запустить его на графическом процессоре, но я пытаюсь обучить сеть на графическом процессоре, сохранить параметры и затем загрузить его в режиме процессора для прогнозов.Я просто ищу способ загрузки параметров в режиме ЦП.

Я также пытался загрузить параметры:

check = torch.load('rnn_x_epoch.net')

Не сработало.

Я пытался сохранить параметры модели двумя способами, чтобы посмотреть, сработает ли какой-либо из них, но не получилось: 1)

checkpoint = {'n_hidden': net.n_hidden,
          'n_layers': net.n_layers,
          'state_dict': net.state_dict(),
          'tokens': net.chars}

with open('rnn_x_epoch.net', 'wb') as f:
    torch.save(checkpoint, f)

2)

torch.save(model.state_dict(), 'rnn_x_epoch.net')

TraceBack:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-9-e61f28013b35> in <module>()
----> 1 net.load_state_dict(torch.load('rnn_x_epoch.net'))

/opt/conda/lib/python3.6/site-packages/torch/serialization.py in load(f, map_location, pickle_module)
    301         f = open(f, 'rb')
    302     try:
--> 303         return _load(f, map_location, pickle_module)
    304     finally:
    305         if new_fd:

/opt/conda/lib/python3.6/site-packages/torch/serialization.py in _load(f, map_location, pickle_module)
    467     unpickler = pickle_module.Unpickler(f)
    468     unpickler.persistent_load = persistent_load
--> 469     result = unpickler.load()
    470 
    471     deserialized_storage_keys = pickle_module.load(f)

/opt/conda/lib/python3.6/site-packages/torch/serialization.py in persistent_load(saved_id)
    435             if root_key not in deserialized_objects:
    436                 deserialized_objects[root_key] = restore_location(
--> 437                     data_type(size), location)
    438             storage = deserialized_objects[root_key]
    439             if view_metadata is not None:

/opt/conda/lib/python3.6/site-packages/torch/serialization.py in default_restore_location(storage, location)
     86 def default_restore_location(storage, location):
     87     for _, _, fn in _package_registry:
---> 88         result = fn(storage, location)
     89         if result is not None:
     90             return result

/opt/conda/lib/python3.6/site-packages/torch/serialization.py in _cuda_deserialize(obj, location)
     68     if location.startswith('cuda'):
     69         device = max(int(location[5:]), 0)
---> 70         return obj.cuda(device)
     71 
     72 

/opt/conda/lib/python3.6/site-packages/torch/_utils.py in _cuda(self, device, non_blocking, **kwargs)
     66         if device is None:
     67             device = -1
---> 68     with torch.cuda.device(device):
     69         if self.is_sparse:
     70             new_type = getattr(torch.cuda.sparse, 
self.__class__.__name__)

/opt/conda/lib/python3.6/site-packages/torch/cuda/__init__.py in __enter__(self)
    223         if self.idx is -1:
    224             return
--> 225         self.prev_idx = torch._C._cuda_getDevice()
    226         if self.prev_idx != self.idx:
    227             torch._C._cuda_setDevice(self.idx)

RuntimeError: cuda runtime error (35) : CUDA driver version is insufficient for CUDA runtime version at torch/csrc/cuda/Module.cpp:51

Также может быть, что операции сохранения / загрузки в Pytorch предназначены только для режима GPU, но я не особо в этом убежден.

1 Ответ

0 голосов
/ 30 января 2019

Из документации PyTorch :

Когда вы вызываете torch.load() для файла, который содержит тензоры GPU, эти тензоры будут загружаться в GPU по умолчанию.

Чтобы загрузить модель на CPU, который был сохранен на GPU, вам нужно передать аргумент map_location в виде cpu в load функции следующим образом:

# Load all tensors onto the CPU
net.load_state_dict(torch.load('rnn_x_epoch.net', map_location=torch.device('cpu')))

При этом хранилища, лежащие в основе тензоров, динамически сопоставляются с устройством ЦП, используя аргумент map_location.Вы можете прочитать больше на официальных PyTorch руководствах .

Это также можно сделать следующим образом:

# Load all tensors onto the CPU, using a function
net.load_state_dict(torch.load('rnn_x_epoch.net', map_location=lambda storage, loc: storage))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...