Проверьте это:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
print("Hi ma")
print(x)
x = F.relu(x)
return x
n = Net()
r = n(torch.tensor(-1))
print(r)
r = n.forward(torch.tensor(1)) #not planned to call directly
print(r)
out:
Hi ma
tensor(-1)
tensor(0)
Hi ma
tensor(1)
tensor(1)
Следует помнить, что forward
не следует вызывать напрямую.PyTorch сделал этот объект модуля n
вызываемым.Они реализовали callable как:
def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
hook(self, input)
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values():
hook_result = hook(self, input, result)
if hook_result is not None:
raise RuntimeError(
"forward hooks should never return any values, but '{}'"
"didn't return None".format(hook))
if len(self._backward_hooks) > 0:
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in self._backward_hooks.values():
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
return result
И просто n()
автоматически вызовет forward
.
В общем, __init__
определяет структуру модуля и forward()
определяет операции для одного пакета.
Эта операция может повторяться при необходимости для некоторых элементов структуры, или вы можете вызывать функции для тензоров напрямую, как мы это делали x = F.relu(x)
.
Вы получили это великолепновсе в PyTorch будет выполняться партиями (мини-пакетами), поскольку PyTorch оптимизирован для такой работы.
Это означает, что при чтении изображения вы будете читать не один, а один bs
Пакеты изображений.