Как обернуть torch.autograd.Fucntion.apply, чтобы модуль можно было распечатать? - PullRequest
0 голосов
/ 13 июля 2020

Я определяю новую функцию активации, и я написал прямую и обратную реализацию, следуя официальному примеру по этой ссылке: https://pytorch.org/docs/stable/autograd.html Вот как выглядит моя активация:

from torch.autograd import Function

class my_activation(Function):
    """my unique activation"""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)  # save input for backward pass
        res = ...do something with input...
        return res

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        input, = ctx.saved_tensors  # restore input from context
        grad_input = ...do something with input and grad_output...
        return grad_input

Так что в моей модели я могу создать экземпляр своей активации с помощью операции .apply, так что моя модель будет выглядеть примерно так:

class MyNet(nn.Module):
    def __init__(self, num_classes=10):
        super(AlexNet, self).__init__()
        self.conv1 = nn.Conv2d(...)
        self.activation = my_activation.apply

    def forward(self, x):
        x = self.conv1(x)
        x = self.activation(x)
        return x

Этот код работает нормально, но я хочу иметь возможность инициализировать его как обычный модуль pytorch и вижу, что мой слой активации распечатывается, когда я печатаю (модель) В частности, у меня есть 2 вопроса, я чувствую, что это можно сделать одним способом:

  1. Как мне обернуть вверх my_activation, чтобы я мог инициализировать его так же, как nn.Conv2d или nn.ReLU:
self.conv1 = nn.Conv2d(...)
self.relu = nn.ReLU(...)
self.my_activation = my_activation() # instead of self.my_activation = my_activation.apply
Как мне распечатать my_activation, когда я печатаю M yNet?
model = MyNet(...)
print(model) # I want to be able to see my model info including my self-defined activation I mentioned above, currently I am able to see all other parts of the models I use from pytorch built-in modules like nn.ReLU and nn.Conv2d in a very structured manner

Я подозреваю, что это связано с тем, что активация не наследует torch.nn.Module, как torch.nn .ReLU? Моя попытка, для моего первого вопроса, я попытался написать такую ​​функцию-оболочку:

class my_activation(object):
    def __init__(self):
        self.act = _my_activation.apply # I changed my previous my_activation class name to _my_activation

    def __call__(self, *args, **kwargs):
        return self.tri

Итак, теперь я могу

self.act = my_activation()()

Но это все еще странно ...

Спасибо, что прочитали этот длинный пост, заранее благодарю за вашу помощь!

...