экспорт гирлянды в кера - PullRequest
0 голосов
/ 29 мая 2018

У меня есть предварительно обученная модель pytorch, вес которой мне нужно использовать в другой модели keras.

Я пытаюсь pytorch2keras репозиторий GitHub для преобразования веса Pytorch .pth в керас .h5

моя модель выглядит следующим образом:

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

class BaseNetwork(nn.Module):
def __init__(self, name, channels=1):
    super(BaseNetwork, self).__init__()
    self._name = name
    self._channels = channels

def name(self):
    return self._name

def initialize_weights(self):
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform(m.weight, gain=np.sqrt(2))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()
class ImageWiseNetwork(BaseNetwork):
def __init__(self, channels=1):
    super(ImageWiseNetwork, self).__init__('iw' + str(channels), channels)

    self.features = nn.Sequential(
        # Block 1
        nn.Conv2d(in_channels=12 * channels, out_channels=64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=2),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),

        # Block 2
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=128, out_channels=128, kernel_size=2, stride=2),
        nn.BatchNorm2d(128),
        nn.ReLU(inplace=True),

        nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1),
    )

    self.classifier = nn.Sequential(
        nn.Linear(1 * 16 * 16, 128),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5, inplace=True),

        nn.Linear(128, 128),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5, inplace=True),

        nn.Linear(128, 64),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5, inplace=True),

        nn.Linear(64, 4),
    )

    self.initialize_weights()

def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), -1)
    x = self.classifier(x)
    x = F.log_softmax(x, dim=1)
    return x

создал объектную модель с model = ImageWiseNetwork(), а затем загрузил обученные веса как:

model.load_state_dict(torch.load('Path\to\weights\weights_iw1.pth'))

, затем

input_np = np.random.uniform(0, 1, (3, 12, 512, 512))
input_var = Variable(torch.FloatTensor(input_np))and the

, а затем

from converter import pytorch_to_keras
k_model = pytorch_to_keras(model, input_var, (3, 512, 512,), verbose=True)  

Я получаю следующую ошибку с трассировкой:

> > RuntimeError                              Traceback (most recent call last) <ipython-input-29-7c5d264109b9> in <module>()
>       1 from converter import pytorch_to_keras
> ----> 2 k_model = pytorch_to_keras(model, input_var, (3, 512, 512,), verbose=True)
> 
> ~\Downloads\pytorch2keras-master\pytorch2keras-master\pytorch2keras\converter.py
> in pytorch_to_keras(model, args, input_shape, change_ordering,
> training, verbose)
>      84 
>      85     with set_training(model, training):
> ---> 86         trace, torch_out = torch.jit.get_trace_graph(model, args)
>      87 
>      88     if orig_state_dict_keys != _unique_state_dict(model).keys():
> 
> ~\Anaconda3\lib\site-packages\torch\jit\__init__.py in
> get_trace_graph(f, args, kwargs, nderivs)
>     253     if not isinstance(args, tuple):
>     254         args = (args,)
> --> 255     return LegacyTracedModule(f, nderivs=nderivs)(*args, **kwargs)
>     256 
>     257 
> 
> ~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in
> __call__(self, *input, **kwargs)
>     487             hook(self, input)
>     488         if torch.jit._tracing:
> --> 489             result = self._slow_forward(*input, **kwargs)
>     490         else:
>     491             result = self.forward(*input, **kwargs)
> 
> ~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in
> _slow_forward(self, *input, **kwargs)
>     465     def _slow_forward(self, *input, **kwargs):
>     466         input_vars = tuple(torch.autograd.function._iter_tensors(input))
> --> 467         tracing_state = torch.jit.get_tracing_state(input_vars)
>     468         if not tracing_state:
>     469             return self.forward(*input, **kwargs)
> 
> ~\Anaconda3\lib\site-packages\torch\jit\__init__.py in
> get_tracing_state(args)
>      33     if not torch._C._is_tracing(args):
>      34         return None
> ---> 35     return torch._C._get_tracing_state(args)
>      36 
>      37 
> 
> RuntimeError:
> C:\ProgramData\Miniconda3\conda-bld\pytorch_1524549877902\work\torch/csrc/jit/tracer.h:117:
> getTracingState: Assertion `var_state == state` failed.

Я не могу определить ошибку.Что может быть возможной причиной этой ошибки.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...