Я пытаюсь построить модель, используя Efficien tNet -B0. Детали модели показаны в коде ниже. Когда я пытался учиться, у меня возникла следующая ошибка.
TypeError Traceback (most recent call last)
'''
<ipython-input-17-fb3850894108> in forward(self, *x)
24 #x: bs*N x 3 x 128 x 128
25 print(x.shape) #([384, 3, 224, 224])
---> 26 x = self.enc(x)
27 #x: bs*N x C x 4 x 4
28 shape = x.shape
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/container.py in forward(self, input)
98 def forward(self, input):
99 for module in self:
--> 100 input = module(input)
101 return input
102
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)
TypeError: forward() takes 1 positional argument but 2 were given
Я подозреваю, что m.children () может иметь эффект. Если кто-нибудь знает причину этой ошибки, дайте мне знать. Спасибо.
class Model(nn.Module):
def __init__(self, arch='efficientnet-b0', n=6, pre=True):
super().__init__()
m = EfficientNet.from_pretrained('efficientnet-b0')
#print(*list(m.children())
nc = m._fc.in_features
print(nc)
self.enc = nn.Sequential(*list(m.children())[:-2])
#nc = list(m.children())[-2].in_features
self.head = nn.Sequential(*list(m.children())[-2:])
self.head._fc = nn.Linear(nc, n)
#self.enc = nn.Sequential(*list(m.children()))
#print('fc_infeatures : {}'.format(nc))
#self.head = nn.Sequential(Flatten(),nn.Linear(nc,512),
# relu(),nn.BatchNorm1d(512), nn.Dropout(0.5),nn.Linear(512,n))
def forward(self, *x):
print(x[0].shape) #([32, 3, 224, 224])
shape = x[0].shape
n = len(x) # n = 12
#torch.stack直後では32*12*3*224*224(bs x N x 3 x 224 x 224)
x = torch.stack(x,1).view(-1,shape[-3],shape[-2],shape[-1])
#x: bs*N x 3 x 128 x 128
print(x.shape) #([384, 3, 224, 224])
x = self.enc(x)
#x: bs*N x C x 4 x 4
shape = x.shape
#concatenate the output for tiles into a single map
x = x.view(-1,n,shape[1],shape[2],shape[3]).permute(0,2,1,3,4).contiguous()\
.view(-1,shape[1],shape[2]*n,shape[3])
#x: bs x C x N*4 x 4
x = self.head(x)
return x