Я думаю, что это всего лишь общий совет по работе с PyTorch Module
. Модули transformers
имеют значение nn.Module
с, и для них требуется метод forward
. Однако не следует звонить model.forward()
вручную, а вместо этого звонить model()
. Причина в том, что PyTorch делает что-то наподобие при вызове модуля. Вы можете найти, что в исходный код .
def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
result = hook(self, input)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
input = result
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values():
hook_result = hook(self, input, result)
if hook_result is not None:
result = hook_result
if len(self._backward_hooks) > 0:
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in self._backward_hooks.values():
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
return result
Вы увидите, что forward
вызывается при необходимости.