Как я могу заставить torch.jit.trace вычислить мой модуль, игнорируя хуки? - PullRequest
1 голос
/ 21 мая 2019

У меня есть модуль с хуком, и я хотел бы скомпилировать его с помощью jit:

compiled_model = torch.jit.trace(model,  torch.rand(1, 3, 256, 256))

Но я получаю ошибку:

ValueError: Modules that have hooks assigned can't be compiled

Как заставить трассировку игнорировать хуки?

1 Ответ

0 голосов
/ 21 мая 2019

Если вы хотите обойти проверку трассировки, вы можете рекурсивно удалить все хуки из вашей модели.

Это можно сделать, выполнив итерации по дочерним элементам:

from collections import OrderedDict
def remove_hooks(model):
    model._backward_hooks = OrderedDict()
    model._forward_hooks = OrderedDict()
    model._forward_pre_hooks = OrderedDict()
    for child in model.children():
        remove_hooks(child)

Затем вы можетевызвать компиляцию:

remove_hooks(model)
compiled_model = torch.jit.trace(model,  torch.rand(1, 3, 256, 256))

Но если ловушка действительно выполняет реальную работу, и вы хотите сохранить их в курсе (как было в моем случае), вы можете просто прокомментировать повышение факела в torch/jit/__init__.py строках:

if orig._backward_hooks or orig._forward_hooks or orig._forward_pre_hooks:
    raise ValueError("Modules that have hooks assigned can't be compiled")

Это сработало для меня, и мне удалось скомпилировать модель fasti.

...