Можно ли использовать пользовательский torch.autograd.Function
в объекте nn.Sequential
или я должен явно использовать объект nn.Module
с функцией forward. В частности, я пытаюсь реализовать разреженный авто-кодер, и мне нужно добавить L1-расстояние кода (скрытое представление) к потере. Я определил пользовательский torch.autograd.Function
L1Penalty ниже, а затем попытался использовать его внутри nn.Sequential
объекта, как показано ниже. Однако при запуске я получил ошибку TypeError: __main__.L1Penalty is not a Module subclass
Как я могу решить эту проблему?
class L1Penalty(torch.autograd.Function):
@staticmethod
def forward(ctx, input, l1weight = 0.1):
ctx.save_for_backward(input)
ctx.l1weight = l1weight
return input, None
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_variables
grad_input = input.clone().sign().mul(ctx.l1weight)
grad_input+=grad_output
return grad_input
model = nn.Sequential(
nn.Linear(10, 10),
nn.ReLU(),
nn.Linear(10, 6),
nn.ReLU(),
# sparsity
L1Penalty(),
nn.Linear(6, 10),
nn.ReLU(),
nn.Linear(10, 10),
nn.ReLU()
).to(device)