У меня есть предварительно обученная модель pytorch, вес которой мне нужно использовать в другой модели keras.
Я пытаюсь pytorch2keras репозиторий GitHub для преобразования веса Pytorch .pth в керас .h5
моя модель выглядит следующим образом:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
class BaseNetwork(nn.Module):
def __init__(self, name, channels=1):
super(BaseNetwork, self).__init__()
self._name = name
self._channels = channels
def name(self):
return self._name
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_uniform(m.weight, gain=np.sqrt(2))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
class ImageWiseNetwork(BaseNetwork):
def __init__(self, channels=1):
super(ImageWiseNetwork, self).__init__('iw' + str(channels), channels)
self.features = nn.Sequential(
# Block 1
nn.Conv2d(in_channels=12 * channels, out_channels=64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=2),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
# Block 2
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=2, stride=2),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1),
)
self.classifier = nn.Sequential(
nn.Linear(1 * 16 * 16, 128),
nn.ReLU(inplace=True),
nn.Dropout(0.5, inplace=True),
nn.Linear(128, 128),
nn.ReLU(inplace=True),
nn.Dropout(0.5, inplace=True),
nn.Linear(128, 64),
nn.ReLU(inplace=True),
nn.Dropout(0.5, inplace=True),
nn.Linear(64, 4),
)
self.initialize_weights()
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
x = F.log_softmax(x, dim=1)
return x
создал объектную модель с model = ImageWiseNetwork()
, а затем загрузил обученные веса как:
model.load_state_dict(torch.load('Path\to\weights\weights_iw1.pth'))
, затем
input_np = np.random.uniform(0, 1, (3, 12, 512, 512))
input_var = Variable(torch.FloatTensor(input_np))and the
, а затем
from converter import pytorch_to_keras
k_model = pytorch_to_keras(model, input_var, (3, 512, 512,), verbose=True)
Я получаю следующую ошибку с трассировкой:
> > RuntimeError Traceback (most recent call last) <ipython-input-29-7c5d264109b9> in <module>()
> 1 from converter import pytorch_to_keras
> ----> 2 k_model = pytorch_to_keras(model, input_var, (3, 512, 512,), verbose=True)
>
> ~\Downloads\pytorch2keras-master\pytorch2keras-master\pytorch2keras\converter.py
> in pytorch_to_keras(model, args, input_shape, change_ordering,
> training, verbose)
> 84
> 85 with set_training(model, training):
> ---> 86 trace, torch_out = torch.jit.get_trace_graph(model, args)
> 87
> 88 if orig_state_dict_keys != _unique_state_dict(model).keys():
>
> ~\Anaconda3\lib\site-packages\torch\jit\__init__.py in
> get_trace_graph(f, args, kwargs, nderivs)
> 253 if not isinstance(args, tuple):
> 254 args = (args,)
> --> 255 return LegacyTracedModule(f, nderivs=nderivs)(*args, **kwargs)
> 256
> 257
>
> ~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in
> __call__(self, *input, **kwargs)
> 487 hook(self, input)
> 488 if torch.jit._tracing:
> --> 489 result = self._slow_forward(*input, **kwargs)
> 490 else:
> 491 result = self.forward(*input, **kwargs)
>
> ~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in
> _slow_forward(self, *input, **kwargs)
> 465 def _slow_forward(self, *input, **kwargs):
> 466 input_vars = tuple(torch.autograd.function._iter_tensors(input))
> --> 467 tracing_state = torch.jit.get_tracing_state(input_vars)
> 468 if not tracing_state:
> 469 return self.forward(*input, **kwargs)
>
> ~\Anaconda3\lib\site-packages\torch\jit\__init__.py in
> get_tracing_state(args)
> 33 if not torch._C._is_tracing(args):
> 34 return None
> ---> 35 return torch._C._get_tracing_state(args)
> 36
> 37
>
> RuntimeError:
> C:\ProgramData\Miniconda3\conda-bld\pytorch_1524549877902\work\torch/csrc/jit/tracer.h:117:
> getTracingState: Assertion `var_state == state` failed.
Я не могу определить ошибку.Что может быть возможной причиной этой ошибки.