Как я могу загрузить модель в PyTorch без переопределения модели? - PullRequest
1 голос
/ 16 января 2020

Я ищу способ сохранить модель Pytorch и загрузить ее без определения модели. Под этим я подразумеваю, что хочу сохранить свою модель, включая определение модели.

Например, мне бы хотелось иметь два сценария. Первый определит, обучит и сохранит модель. Второй будет загружать и прогнозировать модель без включения определения модели.

Метод, использующий torch.save(), torch.load(), требует, чтобы я включил определение модели в сценарий прогнозирования, но я хочу найти способ загрузить модель без переопределить его в сценарии.

1 Ответ

2 голосов
/ 16 января 2020

Вы можете попытаться экспортировать вашу модель в TorchScript , используя tracing . Это имеет ограничения. Благодаря тому, как PyTorch строит график вычислений модели на лету, если у вас есть какой-либо поток управления в вашей модели, то экспортированная модель может не полностью представлять ваш python модуль. TorchScript поддерживается только в PyTorch> = 1.0.0, хотя я бы порекомендовал использовать последнюю возможную версию.

Например, модель без какого-либо условного поведения подходит

from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc = nn.Linear(20 * 4 * 4, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn2(x)
        x = self.fc(x.flatten(1))
        return x

Мы можем экспортируйте это следующим образом

from torch import jit

net = Model()
# ... train your model

# put model in the mode you want to export (see bolded comment below)
net.eval()

# print example output
x = torch.ones(1, 3, 16, 16)
print(net(x))

# create TorchScript by tracing the computation graph with an example input
x = torch.ones(1, 3, 16, 16)
net_trace = jit.trace(net, x)
jit.save(net_trace, 'model.zip')

В случае успеха мы можем загрузить нашу модель в новый сценарий python без использования Model.

from torch import jit
net = jit.load('model.zip')

# print example output (should be same as during save)
x = torch.ones(1, 3, 16, 16)
print(net(x))

Загруженная модель также обучаема, однако загруженная модель будет вести себя только в том режиме, в котором она была экспортирована в . Например, в этом случае мы экспортировали нашу модель в режиме eval(), поэтому использование net.train() на загруженном модуле не даст никакого эффекта.


Control-flow

Модель как это, который имеет поведение, которое изменяется между проходами, не будет должным образом экспортироваться. Только код, оцененный во время jit.trace, будет экспортирован.

from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(20)
        self.fca = nn.Linear(20 * 4 * 4, 2)
        self.fcb = nn.Linear(20 * 4 * 4, 2)

        self.use_a = True

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)
        x = self.bn2(x)
        if self.use_a:
            x = self.fca(x.flatten(1))
        else:
            x = self.fcb(x.flatten(1))
        return x

Мы все еще можем экспортировать модель следующим образом

import torch
from torch import jit

net = Model()
# ... train your model

net.eval()

# print example input
x = torch.ones(1, 3, 16, 16)
net.use_a = True
print('a:', net(x))
net.use_a = False
print('b:', net(x))

# save model
x = torch.ones(1, 3, 16, 16)
net_trace = jit.trace(net, x)
jit.save(net_trace, "model.ts")

В этом случае выходные данные примера будут

a: tensor([[-0.0959,  0.0657]], grad_fn=<AddmmBackward>)
b: tensor([[ 0.1437, -0.0033]], grad_fn=<AddmmBackward>)

Однако загрузка

import torch
from torch import jit

net = jit.load("model.ts")

# will not match the output from before
x = torch.ones(1, 3, 16, 16)
net.use_a = True
print('a:', net(x))
net.use_a = False
print('b:', net(x))

приводит к

a: tensor([[ 0.1437, -0.0033]], grad_fn=<DifferentiableGraphBackward>)
b: tensor([[ 0.1437, -0.0033]], grad_fn=<DifferentiableGraphBackward>)

Обратите внимание, что лог c ветви "a" отсутствует, поскольку net.use_a было False когда был вызван jit.trace.


Сценарии

Эти ограничения могут быть преодолены, но требуют определенных усилий с вашей стороны. Вы можете использовать функциональность scripting , чтобы гарантировать, что все логи c экспортированы.

...