Если вы просто не хотите вычислять градиенты для вашего TransformationFunction
, проще всего отключить вычисление градиента для всех параметров, участвующих в этом вычислении, установив флаг requires_grad
на False
.
- Без учета подграфов в обратном направлении:
Если в операции есть один вход, требующийградиент, для его вывода также потребуется градиент.И наоборот, только если все входные данные не требуют градиента, выходные данные также не будут требовать его.Вычисления в обратном направлении никогда не выполняются в подграфах, где всем тензорам не требуются градиенты.
Это особенно полезно, когда вы хотите заморозить часть своей модели или заранее знаете, что не собираетесьиспользовать градиенты по некоторым параметрам.Например, если вы хотите выполнить предварительную настройку предварительно обученного CNN, достаточно переключить флаги requires_grad
в замороженной базе, и никакие промежуточные буферы не будут сохранены, пока вычисление не дойдет до последнего уровня, где аффинное преобразование будет использовать веса, которыетребуют градиента, и выход сети также будет требовать их.
Вот небольшой пример, который может сделать это:
import torch
import torch.nn as nn
# define layers
normal_layer = nn.Linear(5, 5)
TransformationFunction = nn.Linear(5, 5)
# disable gradient computation for parameters of TransformationFunction
# here weight and bias
TransformationFunction.weight.requires_grad = False
TransformationFunction.bias.requires_grad = False
# input
inp = torch.rand(1, 5)
# do computation
out = normal_layer(inp)
out = TransformationFunction(out)
# loss
loss = torch.sum(out)
# backward
loss.backward()
# gradient for l1
print('Gradients for "normal_layer"', normal_layer.weight.grad, normal_layer.bias.grad)
# gradient for l2
print('Gradients for "TransformationFunction"', TransformationFunction.weight.grad, TransformationFunction.bias.grad)
Вывод:
Gradients for "normal_layer" tensor([[0.1607, 0.0215, 0.0192, 0.2595, 0.0811],
[0.0788, 0.0105, 0.0094, 0.1272, 0.0398],
[0.1552, 0.0207, 0.0186, 0.2507, 0.0784],
[0.1541, 0.0206, 0.0184, 0.2489, 0.0778],
[0.2945, 0.0393, 0.0352, 0.4756, 0.1486]]) tensor([0.2975, 0.1458, 0.2874, 0.2853, 0.5452])
Gradients for "TransformationFunction" None None
Надеюсь, это то, что вы искали, если нет, пожалуйста, отредактируйте свой вопрос более подробно!