Вы можете попытаться экспортировать вашу модель в TorchScript , используя tracing . Это имеет ограничения. Благодаря тому, как PyTorch строит график вычислений модели на лету, если у вас есть какой-либо поток управления в вашей модели, то экспортированная модель может не полностью представлять ваш python модуль. TorchScript поддерживается только в PyTorch> = 1.0.0, хотя я бы порекомендовал использовать последнюю возможную версию.
Например, модель без какого-либо условного поведения подходит
from torch import nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 10, 3, padding=1)
self.bn1 = nn.BatchNorm2d(10)
self.conv2 = nn.Conv2d(10, 20, 3, padding=1)
self.bn2 = nn.BatchNorm2d(20)
self.fc = nn.Linear(20 * 4 * 4, 2)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = F.max_pool2d(x, 2, 2)
x = self.bn1(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2, 2)
x = self.bn2(x)
x = self.fc(x.flatten(1))
return x
Мы можем экспортируйте это следующим образом
from torch import jit
net = Model()
# ... train your model
# put model in the mode you want to export (see bolded comment below)
net.eval()
# print example output
x = torch.ones(1, 3, 16, 16)
print(net(x))
# create TorchScript by tracing the computation graph with an example input
x = torch.ones(1, 3, 16, 16)
net_trace = jit.trace(net, x)
jit.save(net_trace, 'model.zip')
В случае успеха мы можем загрузить нашу модель в новый сценарий python без использования Model
.
from torch import jit
net = jit.load('model.zip')
# print example output (should be same as during save)
x = torch.ones(1, 3, 16, 16)
print(net(x))
Загруженная модель также обучаема, однако загруженная модель будет вести себя только в том режиме, в котором она была экспортирована в . Например, в этом случае мы экспортировали нашу модель в режиме eval()
, поэтому использование net.train()
на загруженном модуле не даст никакого эффекта.
Control-flow
Модель как это, который имеет поведение, которое изменяется между проходами, не будет должным образом экспортироваться. Только код, оцененный во время jit.trace
, будет экспортирован.
from torch import nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 10, 3, padding=1)
self.bn1 = nn.BatchNorm2d(10)
self.conv2 = nn.Conv2d(10, 20, 3, padding=1)
self.bn2 = nn.BatchNorm2d(20)
self.fca = nn.Linear(20 * 4 * 4, 2)
self.fcb = nn.Linear(20 * 4 * 4, 2)
self.use_a = True
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = F.max_pool2d(x, 2, 2)
x = self.bn1(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2, 2)
x = self.bn2(x)
if self.use_a:
x = self.fca(x.flatten(1))
else:
x = self.fcb(x.flatten(1))
return x
Мы все еще можем экспортировать модель следующим образом
import torch
from torch import jit
net = Model()
# ... train your model
net.eval()
# print example input
x = torch.ones(1, 3, 16, 16)
net.use_a = True
print('a:', net(x))
net.use_a = False
print('b:', net(x))
# save model
x = torch.ones(1, 3, 16, 16)
net_trace = jit.trace(net, x)
jit.save(net_trace, "model.ts")
В этом случае выходные данные примера будут
a: tensor([[-0.0959, 0.0657]], grad_fn=<AddmmBackward>)
b: tensor([[ 0.1437, -0.0033]], grad_fn=<AddmmBackward>)
Однако загрузка
import torch
from torch import jit
net = jit.load("model.ts")
# will not match the output from before
x = torch.ones(1, 3, 16, 16)
net.use_a = True
print('a:', net(x))
net.use_a = False
print('b:', net(x))
приводит к
a: tensor([[ 0.1437, -0.0033]], grad_fn=<DifferentiableGraphBackward>)
b: tensor([[ 0.1437, -0.0033]], grad_fn=<DifferentiableGraphBackward>)
Обратите внимание, что лог c ветви "a" отсутствует, поскольку net.use_a
было False
когда был вызван jit.trace
.
Сценарии
Эти ограничения могут быть преодолены, но требуют определенных усилий с вашей стороны. Вы можете использовать функциональность scripting , чтобы гарантировать, что все логи c экспортированы.