pytorch отслеживает градиенты через вычислительный график тензоров , а не через функции. Пока ваши тензоры обладают свойством requires_grad=True
и их grad
не None
, вы можете делать (почти) все, что вам нравится, и при этом иметь возможность делать обратное.
Пока вы используете операции pytorch (например, перечисленные в здесь и здесь ), все будет в порядке.
Для получения дополнительной информации см. this .
Например (взято из реализации VGG torchvision ):
class VGG(nn.Module):
def __init__(self, features, num_classes=1000, init_weights=True):
super(VGG, self).__init__()
# ...
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1) # <-- what you were asking about
x = self.classifier(x)
return x
A Более сложный пример можно увидеть в реализации в Torchvision Re sNet:
class Bottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
# ...
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None: # <-- conditional execution!
identity = self.downsample(x)
out += identity # <-- inplace operations
out = self.relu(out)
return out