Я использую PyTorch 1.4, и мне нужно экспортировать модель со свертками внутри al oop в forward
:
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
def forward(self, x):
for i in range(5):
conv = torch.nn.Conv1d(1, 1, 2*i+3)
x = torch.nn.Relu()(conv(x))
return x
torch.jit.script(MyCell())
Это дает следующую ошибку:
RuntimeError:
Arguments for call are not valid.
The following variants are available:
_single(float[1] x) -> (float[]):
Expected a value of type 'List[float]' for argument 'x' but instead found type 'Tensor'.
_single(int[1] x) -> (int[]):
Expected a value of type 'List[int]' for argument 'x' but instead found type 'Tensor'.
The original call is:
File "***/torch/nn/modules/conv.py", line 187
padding=0, dilation=1, groups=1,
bias=True, padding_mode='zeros'):
kernel_size = _single(kernel_size)
~~~~~~~ <--- HERE
stride = _single(stride)
padding = _single(padding)
'Conv1d.__init__' is being compiled since it was called from 'Conv1d'
File "***", line ***
def forward(self, x):
for _ in range(5):
conv = torch.nn.Conv1d(1, 1, 2*i+3)
~~~~~~~~~~~~~~~ <--- HERE
x = torch.nn.Relu()(conv(x))
return x
'Conv1d' is being compiled since it was called from 'MyCell.forward'
File "***", line ***
def forward(self, x, h):
for _ in range(5):
conv = torch.nn.Conv1d(1, 1, 2*i+3)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
x = torch.nn.Relu()(conv(x))
return x
Я также пытался предварительно определить conv
и затем поместить их в список внутри __init__
, но TorchScript не допускает такой тип:
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.conv = [torch.nn.Conv1d(1, 1, 2*i+3) for i in range(5)]
def forward(self, x):
for i in range(len(self.conv)):
x = torch.nn.Relu()(self.conv[i](x))
return x
torch.jit.script(MyCell())
Вместо этого он дает:
RuntimeError:
Module 'MyCell' has no attribute 'conv' (This attribute exists on the Python module, but we failed to convert Python type: 'list' to a TorchScript type.):
File "***", line ***
def forward(self, x):
for i in range(len(self.conv)):
~~~~~~~~~ <--- HERE
x = torch.nn.Relu()(self.conv[i](x))
return x
Так как же экспортировать этот модуль? Справочная информация: я экспортирую Плотные сети смешанного масштаба (источник) в TorchScript; в то время как nn.Sequential
может работать для этого упрощенного случая, практически мне нужно сверяться со всеми историческими выходами свертки в каждой итерации, что больше, чем цепочка слоев.