сохранение и загрузка нейронных сетей pytorch - PullRequest
0 голосов
/ 15 декабря 2018

Итак, я создал нейронную сеть, и я хотел бы сохранить ее и загрузить, когда захочу.В частности, я хочу делать снимки и выполнять обработку в реальном времени.Я использую нейронную сеть, созданную здесь

Я прочитал, что стандартный способ - создать сеть, затем использовать torch.save(net,'mynet'), чтобы сохранить ее, а затем загрузить с помощью torch.load('mynet'). * 1007.*

Однако, если я открою новый python3 терминал и использую:

>>import torch
>>torch.load('mynet')

Это дает мне ошибку:

  File "<stdin>", line 1, in <module>
  File "/home/tim/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 367, in load
    return _load(f, map_location, pickle_module)
  File "/home/tim/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 538, in _load
    result = unpickler.load()
AttributeError: Can't get attribute 'Net' on <module '__main__' (built-in)>

Я думаю, что это из-за отсутствия сетикласс определен.Добавление

import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 15, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(15, 15, 5)
        self.conv3 = nn.Conv2d(15, 10, 3)
        self.fc1 = nn.Linear(10*4*4, 100)
        self.fc2 = nn.Linear(100, 24)
        self.fc3 = nn.Linear(24, 4)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 10*4*4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

- это то, что вам нужно сделать, но зачем нам определять класс нейронной сети?Что если я загружу нейронную сеть с архитектурой, отличной от той, которую я указываю в классе, архитектура, определенная в классе, будет перезаписана?неужели объект, который я загружаю, содержит всю информацию об архитектуре и классе, инкапсулированную в нем?

Обновление: На самом деле, он даже не работает, когда я определяю класс Net.

1 Ответ

0 голосов
/ 15 декабря 2018

Пожалуйста, обратитесь к документации по семантике сериализации , которая сначала описывает предложенный подход, а затем тот, который вы использовали как "сериализованные данные, привязаны к конкретным классам и точной используемой структуре каталогов, поэтому они могутломаться различными способами при использовании в других проектах или после некоторых серьезных рефакторингов. "

Другими словами, вам нужно сохранять / загружать net.state_dict(), а не сам net.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...