Может ли моя функция форварда PyTorch выполнять дополнительные операции? - PullRequest
2 голосов
/ 04 марта 2020

Обычно функция forward объединяет несколько слоев и возвращает выходные данные последнего. Могу ли я выполнить дополнительную обработку после последнего слоя перед возвратом? Например, некоторое скалярное умножение и изменение формы с помощью .view?

Я знаю, что автоград может каким-то образом вычислять градиенты. Так что я не знаю, если моя дополнительная обработка как-то облажается. Спасибо.

1 Ответ

4 голосов
/ 04 марта 2020

отслеживает градиенты через вычислительный график тензоров , а не через функции. Пока ваши тензоры обладают свойством requires_grad=True и их grad не None, вы можете делать (почти) все, что вам нравится, и при этом иметь возможность делать обратное.
Пока вы используете операции pytorch (например, перечисленные в здесь и здесь ), все будет в порядке.

Для получения дополнительной информации см. this .

Например (взято из реализации VGG torchvision ):

class VGG(nn.Module):

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        #  ...

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # <-- what you were asking about
        x = self.classifier(x)
        return x

A Более сложный пример можно увидеть в реализации в Torchvision Re sNet:

class Bottleneck(nn.Module):
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        # ...

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:    # <-- conditional execution!
            identity = self.downsample(x)

        out += identity  # <-- inplace operations
        out = self.relu(out)

        return out
...