torch.no_grad()
отключит информацию о градиенте для результатов операций с тензорами, для которых requires_grad
имеет значение True
.Поэтому рассмотрим следующее:
import torch
net = torch.nn.Linear(4, 3)
input_t = torch.randn(4)
with torch.no_grad():
for name, param in net.named_parameters():
print("{} {}".format(name, param.requires_grad))
out = net(input_t)
print('Output: {}'.format(out))
print('Output requires gradient: {}'.format(out.requires_grad))
print('Gradient function: {}'.format(out.grad_fn))
Это печатает
weight True
bias True
Output: tensor([-0.3311, 1.8643, 0.2933])
Output requires gradient: False
Gradient function: None
Если вы удалите with torch.no_grad()
, вы получите
weight True
bias True
Output: tensor([ 0.5776, -0.5493, -0.9229], grad_fn=<AddBackward0>)
Output requires gradient: True
Gradient function: <AddBackward0 object at 0x7febe41e3240>
Обратите внимание, что в обоих случаях модульдля параметров requires_grad
установлено значение True
, но в первом случае тензор out
не имеет связанной с ним функции градиента, тогда как во втором случае это имеет место.