Конвертация модели pytorch в nn.Модуль для экспорта в onnx для Lens Studio - PullRequest
1 голос
/ 07 августа 2020

Я пытаюсь преобразовать pix2pix в pb или onnx, который может работать в Lens Studio. Студия Lens предъявляет к моделям строгие требования. Я пытаюсь экспортировать эту модель pytorch в onnx, используя это руководство, предоставленное студией объектива. Проблема в том, что модель pytorch , найденная здесь, использует свой собственный базовый класс, когда в Например, он использует Module.nn и поэтому не имеет методов / переменных, которые должна запускать функция torch.onnx.export. Пока что я столкнулся с отсутствием переменной с именем training и метода с именем train

Стоит ли пытаться изменить базовую модель, или я должен попытаться построить ее из поцарапать с помощью nn.Module? Есть ли способ заставить модель pix2pix унаследовать как от абстрактного базового класса, так и от nn.module? Я не понимаю ситуацию? Причина, по которой я хочу сделать это с помощью учебника по студии линз, заключается в том, что я получил его для экспорта onnx разными способами, но Lens Studio не принимает их по разным причинам.

Также я впервые задаю вопрос SO (после 6 лет программирования), дайте мне знать, если я сделаю ошибки, и я смогу их исправить. Спасибо.

Это важный код из учебника по созданию модели pytorch для Lens Studio:

import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Conv2d(in_channels=3, out_channels=1, 
                               kernel_size=3, stride=2, padding=1)
        
    def forward(self, x):
        out = self.layer(x)
        out = nn.functional.interpolate(out, scale_factor=2, 
                                        mode='bilinear', align_corners=True)
        out = torch.nn.functional.softmax(out, dim=1)
        return out

Я не собираюсь включать весь код из модели pytorch b c он большой, но начало baseModel.py -

import os
import torch
from collections import OrderedDict
from abc import ABC, abstractmethod
from . import networks


class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>:                     unpack data from dataset and apply preprocessing.
        -- <forward>:                       produce intermediate results.
        -- <optimize_parameters>:           calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
        Then, you need to define four lists:
            -- self.loss_names (str list):          specify the training losses that you want to plot and save.
            -- self.model_names (str list):         define networks used in our training.
            -- self.visual_names (str list):        specify the images that you want to display and save.
            -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

, а для pix2pix_model.py

import torch
from .base_model import BaseModel
from . import networks


class Pix2PixModel(BaseModel):
    """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.

    The model training requires '--dataset_mode aligned' dataset.
    By default, it uses a '--netG unet256' U-Net generator,
    a '--netD basic' discriminator (PatchGAN),
    and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper).

    pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.

        For pix2pix, we do not use image buffer
        The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
        By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets.
        """
        # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/)
        parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned')
        if is_train:
            parser.set_defaults(pool_size=0, gan_mode='vanilla')
            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')

        return parser

    def __init__(self, opt):
        """Initialize the pix2pix class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """

(Также обратите внимание, если вы видите это, и это выглядит нелегко выход, дайте мне знать, я знаю, каково это видеть кого-то, начинающего что-то делать, кто слишком рано заходит слишком глубоко)

1 Ответ

0 голосов
/ 07 августа 2020

Вы определенно можете наследовать свою модель как от базового класса, так и от torch.nn.Module (python допускает множественное наследование). Однако вы должны позаботиться о конфликтах, если оба унаследованных класса имеют функции с одинаковыми именами (я вижу по крайней мере одну: их база предоставляет функцию eval и, следовательно, nn.module).

Однако, поскольку вы CycleGan не нужен, а большая часть кода совместима с их обучающей средой, вам, вероятно, лучше просто повторно реализовать pix2pix. Просто украдите код, унаследуйте его от nn.Module, скопируйте и вставьте полезные / обязательные функции из базового класса и все переведите в чистый код pytorch. У вас уже есть функция пересылки (которая является единственным требованием для модуля pytorch).

Все подсети, которые они используют (например, блоки re snet), похоже, унаследованы от nn.Module, так что ничего нет изменить здесь (хотя еще раз проверьте это)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...