PyTorch: Что делает декоратор @weak_script_method? - PullRequest
0 голосов
/ 16 февраля 2019

В классе torch.nn.Linear (и других классах тоже) метод forward включает в себя декоратор @weak_script_method следующим образом:

@weak_script_method
def forward(self, input):
    return F.linear(input, self.weight, self.bias)

Что делает этот декоратор? Стоит ли включать его, если я переопределяю метод forward в моем собственном подклассе модуля Linear?

1 Ответ

0 голосов
/ 26 мая 2019

Вы можете найти точное расположение декоратора , чтобы получить идею.

def weak_script_method(fn):
    weak_script_methods[fn] = {
        "rcb": createResolutionCallback(frames_up=2),
        "original_method": fn
    }
return fn

Но вам не нужно беспокоиться об этом декораторе.Этот декоратор является внутренним для JIT .

Технически метод, украшенный @weak_script_method, будет добавлен в созданный спереди словарь weak_script_methods, например:

weak_script_methods = weakref.WeakKeyDictionary() 

Это диктует методы отслеживания, чтобы избежать проблем круговой зависимости;методы, вызывающие другие методы при создании графа PyTorch.


Это действительно не имеет особого смысла, если вы не понимаете концепцию TorchScript в целом.

Идея TorchScript заключается в обучении моделей в PyTorch и экспорте моделей в другую производственную среду, отличную от Python.(читай: C ++ / C / Cuda), которые поддерживают статическую типизацию.

Команда PyTorch сделала TorchScript на ограниченной базе Python для поддержки статической типизации.По умолчанию Python является динамически типизированным языком, но с несколькими хитростями (читай: проверки) он может стать статически типизированным языком.

И поэтому функции TorchScript являются статическитипизированное подмножество Python, содержащее все встроенные в PyTorch операции Tensor.Это различие позволяет запускать код модулей TorchScript без использования интерпретатора Python.

Вы можете либо преобразовать существующие методы PyTorch в TorchScript, используя трассировку (метод torch.jit.trace()), либо создать свои TorchScripts вручную, используя @torch.jit.script decorator.

Если вы используете трассировку, в конце вы получите модуль одного класса.Вот пример:

import inspect

import torch
def foo(x, y):
    return x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

print(type(traced_foo)) #<class 'torch.jit.TopLevelTracedModule'>
print(traced_foo) #foo()
print(traced_foo.forward) #<bound method TopLevelTracedModule.forward of foo()>

lines = inspect.getsource(traced_foo.forward)
print(lines)

Выход:

<class 'torch.jit.TopLevelTracedModule'>
foo()
<bound method TopLevelTracedModule.forward of foo()>
    def forward(self, *args, **kwargs):
        return self._get_method('forward')(*args, **kwargs)

Вы можете продолжить исследование с помощью модуля проверки.Это была просто демонстрация того, как преобразовать одну функцию с помощью трассировки.

...