Итак, я создал нейронную сеть, и я хотел бы сохранить ее и загрузить, когда захочу.В частности, я хочу делать снимки и выполнять обработку в реальном времени.Я использую нейронную сеть, созданную здесь
Я прочитал, что стандартный способ - создать сеть, затем использовать torch.save(net,'mynet')
, чтобы сохранить ее, а затем загрузить с помощью torch.load('mynet')
. * 1007.*
Однако, если я открою новый python3
терминал и использую:
>>import torch
>>torch.load('mynet')
Это дает мне ошибку:
File "<stdin>", line 1, in <module>
File "/home/tim/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 367, in load
return _load(f, map_location, pickle_module)
File "/home/tim/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 538, in _load
result = unpickler.load()
AttributeError: Can't get attribute 'Net' on <module '__main__' (built-in)>
Я думаю, что это из-за отсутствия сетикласс определен.Добавление
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 15, 3)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(15, 15, 5)
self.conv3 = nn.Conv2d(15, 10, 3)
self.fc1 = nn.Linear(10*4*4, 100)
self.fc2 = nn.Linear(100, 24)
self.fc3 = nn.Linear(24, 4)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = self.pool(F.relu(self.conv3(x)))
x = x.view(-1, 10*4*4)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
- это то, что вам нужно сделать, но зачем нам определять класс нейронной сети?Что если я загружу нейронную сеть с архитектурой, отличной от той, которую я указываю в классе, архитектура, определенная в классе, будет перезаписана?неужели объект, который я загружаю, содержит всю информацию об архитектуре и классе, инкапсулированную в нем?
Обновление: На самом деле, он даже не работает, когда я определяю класс Net.