Понимает ли PyTorch loss () и обратное распространение лямбда-слои? - PullRequest
0 голосов
/ 20 января 2020

Я работал с моделью resnet56 из кода, приведенного здесь: https://github.com/akamaster/pytorch_resnet_cifar10/blob/master/resnet.py.

Я заметил, что реализация отличается от многих других доступных примеров Re sNet в Интернете, и мне было интересно, может ли алгоритм обратного распространения PyTorch, использующий loss (), учесть лямбда-слой и ярлык в предоставленном коде ,

Если это так, может ли кто-нибудь дать представление о том, как PyTorch может интерпретировать лямбда-слой для обратного распространения (т. Е. Как PyTorch узнает, как дифференцировать операции этого слоя)?

PS Мне также пришлось изменить код, чтобы он соответствовал моему собственному сценарию использования, и похоже, что моя собственная реализация с параметром == 'A' не дает хороших результатов. Это может быть просто потому, что опция == 'B', которая использует сверточные слои вместо заполнения, лучше для моих данных.

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                top = (int) ((self.expansion*planes - in_planes) / 2)
                bot = (self.expansion*planes - in_planes) - top
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::stride, ::stride], (0, 0, 0, 0, top, bot), "constant", 0))

1 Ответ

0 голосов
/ 24 января 2020

"Мне было интересно, может ли алгоритм обратного распространения PyTorch, использующий loss (), учитывать лямбда-слой и ярлык в предоставленном коде."

У PyTorch нет проблем с обратным распространением через лямбда-функции , Ваш LambdaLayer просто определяет прямой проход модуля как оценку лямбда-функции, поэтому ваш вопрос сводится к тому, может ли PyTorch распространяться через лямбда-функции.

"Если это так, может ли кто-нибудь дать представление о том, как PyTorch может интерпретировать лямбда-слой для обратного распространения (то есть как PyTorch знает, как различать операции этого слоя)?"

Лямбда-функция выполняет функцию torch.nn.functional.Pad для x, которую мы можем распространять, потому что она имеет определенную функцию backwards ().

PyTorch обрабатывает лямбда-функции таким же образом, как инструмент автоопределения, такой как PyTorch, обрабатывает любую функцию: он разбивает ее на примитивные операции и использует правила дифференцирования для каждой примитивной операции для создания производной всего вычисления.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...