Почему я должен вызывать экземпляр модуля BERT, а не метод forward? - PullRequest
0 голосов
/ 16 марта 2020

Я пытаюсь извлечь векторные представления текста с использованием BERT в libray преобразователей, и наткнулся на следующую часть документации для класса "BERTModel":

enter image description here

Кто-нибудь может объяснить это более подробно? Прямой проход имеет для меня интуитивный смысл (в конце концов, я пытаюсь получить окончательные скрытые состояния), и я не могу найти никакой дополнительной информации о том, что означает «предварительная и последующая обработка» в этом контексте.

Спасибо впереди!

1 Ответ

2 голосов
/ 16 марта 2020

Я думаю, что это всего лишь общий совет по работе с PyTorch Module. Модули transformers имеют значение nn.Module с, и для них требуется метод forward. Однако не следует звонить model.forward() вручную, а вместо этого звонить model(). Причина в том, что PyTorch делает что-то наподобие при вызове модуля. Вы можете найти, что в исходный код .

def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result
    if len(self._backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result

Вы увидите, что forward вызывается при необходимости.

...