Когда я хочу выполнить прямой проход с инициализированным объектом model = nn.Sequential
, я просто использую:
out = model(X)
# OR
out = model.forward(X)
Однако я попытался расширить класс Sequential
, и теперь оба этих метода вдруг требуется второй аргумент. Например, обратите внимание, что в следующем методе мой вызов self(x)
:
def train(self, trainloader, epochs):
for e in range(epochs):
for x, y in trainloader:
x = x.view(x.shape[0], -1)
self.optimizer.zero_grad()
loss = self.criterion(self(x), y) # CALL OCCURS HERE
loss.backward()
self.optimizer.step()
Теперь этот код дает мне TypeError: forward() missing 1 required positional argument: 'target'
.
Мой вопрос: Поскольку я не сделал ничего, кроме расширения класса, почему это?
Код для полного класса ниже:
class Network(nn.Sequential):
def __init__(self, layers):
super().__init__(self.init_modules(layers))
self.criterion = nn.NLLLoss()
self.optimizer = optim.Adam(self.parameters(), lr=0.003)
def init_modules(self, layers):
n_layers = len(layers)
modules = OrderedDict()
# Layer definitions for input and inner layers:
for i in range(n_layers - 2):
modules[f'fc{i}'] = nn.Linear(layers[i], layers[i+1])
modules[f'relu{i}'] = nn.ReLU()
# Definition for output layer:
modules['fc_out'] = nn.Linear(layers[-2], layers[-1])
modules['smax_out'] = nn.LogSoftmax(dim=1)
return modules
def train(self, trainloader, epochs):
for e in range(epochs):
for x, y in trainloader:
x = x.view(x.shape[0], -1)
self.optimizer.zero_grad()
loss = self.criterion(self(x), y)
loss.backward()
self.optimizer.step()
Полная трассировка стека:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-63-490e0b9eef22> in <module>
----> 1 model2.train(trainloader, 5, plot_loss=True)
<ipython-input-61-e173e5672f18> in train(self, trainloader, epochs, plot_loss)
32 x = x.view(x.shape[0], -1)
33 self.optimizer.zero_grad()
---> 34 loss = self.criterion(self(x), y)
35 loss.backward()
36 self.optimizer.step()
c:\program files\python38\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)
c:\program files\python38\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
98 def forward(self, input):
99 for module in self:
--> 100 input = module(input)
101 return input
102
c:\program files\python38\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)
TypeError: forward() missing 1 required positional argument: 'target'