Вы можете найти точное расположение декоратора , чтобы получить идею.
def weak_script_method(fn):
weak_script_methods[fn] = {
"rcb": createResolutionCallback(frames_up=2),
"original_method": fn
}
return fn
Но вам не нужно беспокоиться об этом декораторе.Этот декоратор является внутренним для JIT .
Технически метод, украшенный @weak_script_method
, будет добавлен в созданный спереди словарь weak_script_methods
, например:
weak_script_methods = weakref.WeakKeyDictionary()
Это диктует методы отслеживания, чтобы избежать проблем круговой зависимости;методы, вызывающие другие методы при создании графа PyTorch.
Это действительно не имеет особого смысла, если вы не понимаете концепцию TorchScript в целом.
Идея TorchScript заключается в обучении моделей в PyTorch и экспорте моделей в другую производственную среду, отличную от Python.(читай: C ++ / C / Cuda), которые поддерживают статическую типизацию.
Команда PyTorch сделала TorchScript на ограниченной базе Python для поддержки статической типизации.По умолчанию Python является динамически типизированным языком, но с несколькими хитростями (читай: проверки) он может стать статически типизированным языком.
И поэтому функции TorchScript являются статическитипизированное подмножество Python, содержащее все встроенные в PyTorch операции Tensor.Это различие позволяет запускать код модулей TorchScript без использования интерпретатора Python.
Вы можете либо преобразовать существующие методы PyTorch в TorchScript, используя трассировку (метод torch.jit.trace()
), либо создать свои TorchScripts вручную, используя @torch.jit.script
decorator.
Если вы используете трассировку, в конце вы получите модуль одного класса.Вот пример:
import inspect
import torch
def foo(x, y):
return x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
print(type(traced_foo)) #<class 'torch.jit.TopLevelTracedModule'>
print(traced_foo) #foo()
print(traced_foo.forward) #<bound method TopLevelTracedModule.forward of foo()>
lines = inspect.getsource(traced_foo.forward)
print(lines)
Выход:
<class 'torch.jit.TopLevelTracedModule'>
foo()
<bound method TopLevelTracedModule.forward of foo()>
def forward(self, *args, **kwargs):
return self._get_method('forward')(*args, **kwargs)
Вы можете продолжить исследование с помощью модуля проверки.Это была просто демонстрация того, как преобразовать одну функцию с помощью трассировки.