forward () принимает 1 позиционный аргумент, но было дано 2 - PullRequest
0 голосов
/ 29 мая 2020

Я пытаюсь построить модель, используя Efficien tNet -B0. Детали модели показаны в коде ниже. Когда я пытался учиться, у меня возникла следующая ошибка.

TypeError                                 Traceback (most recent call last)
'''
<ipython-input-17-fb3850894108> in forward(self, *x)
     24         #x: bs*N x 3 x 128 x 128
     25         print(x.shape)     #([384, 3, 224, 224])
---> 26         x = self.enc(x)
     27         #x: bs*N x C x 4 x 4
     28         shape = x.shape

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/container.py in forward(self, input)
     98     def forward(self, input):
     99         for module in self:
--> 100             input = module(input)
    101         return input
    102 

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

TypeError: forward() takes 1 positional argument but 2 were given

Я подозреваю, что m.children () может иметь эффект. Если кто-нибудь знает причину этой ошибки, дайте мне знать. Спасибо.

class Model(nn.Module):
    def __init__(self, arch='efficientnet-b0', n=6, pre=True):
        super().__init__()
        m = EfficientNet.from_pretrained('efficientnet-b0')
        #print(*list(m.children())
        nc = m._fc.in_features
        print(nc)
        self.enc = nn.Sequential(*list(m.children())[:-2])
        #nc = list(m.children())[-2].in_features
        self.head = nn.Sequential(*list(m.children())[-2:])
        self.head._fc = nn.Linear(nc, n)
        #self.enc = nn.Sequential(*list(m.children()))
        #print('fc_infeatures : {}'.format(nc))
        #self.head = nn.Sequential(Flatten(),nn.Linear(nc,512),
        #                    relu(),nn.BatchNorm1d(512), nn.Dropout(0.5),nn.Linear(512,n))


    def forward(self, *x):
        print(x[0].shape)  #([32, 3, 224, 224])
        shape = x[0].shape
        n = len(x)           # n = 12
        #torch.stack直後では32*12*3*224*224(bs x N x 3 x 224 x 224)
        x = torch.stack(x,1).view(-1,shape[-3],shape[-2],shape[-1])
        #x: bs*N x 3 x 128 x 128
        print(x.shape)     #([384, 3, 224, 224])
        x = self.enc(x)
        #x: bs*N x C x 4 x 4
        shape = x.shape
        #concatenate the output for tiles into a single map
        x = x.view(-1,n,shape[1],shape[2],shape[3]).permute(0,2,1,3,4).contiguous()\
          .view(-1,shape[1],shape[2]*n,shape[3])
        #x: bs x C x N*4 x 4
        x = self.head(x)
        return x

...