Буферы Pytorch, вызывающие проблемы при автоматизации c Дифференциация - PullRequest
0 голосов
/ 07 августа 2020

Я хочу реализовать пакетную нормализацию самостоятельно в pytorch, однако у меня есть проблемы с буферами для хранения текущего среднего и дисперсии. Это мой код:

class BatchNorm(nn.Module):

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        super(BatchNorm, self).__init__()
        self.num_features = num_features
        self.momentum = momentum
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(num_features, requires_grad=True))
        self.beta = nn.Parameter(torch.zeros(num_features, requires_grad=True))
        self.register_buffer('mu', torch.zeros(num_features))
        self.register_buffer('var', torch.ones(num_features))

    def forward(self, x):
        mu, var = self.update_running_stats(x)
        normalizer = torch.sqrt(var + self.eps)
        x = (x - mu) / normalizer # element-wise standardization
        x = self.alpha * x + self.beta # scale and shift
        return x

    def update_running_stats(self, x):
        batch_mu = x.sum(0) / x.size(0)
        batch_var = ((x - batch_mu)**2).sum(0) / x.size(0)
        self.mu = (1. - self.momentum) * self.mu + self.momentum * batch_mu
        self.var = (1. - self.momentum) * self.var + self.momentum * batch_var
        return self.mu, self.var

Когда я запускаю свой код, как этот, я (с вызовом backward() при потере только один раз), я получаю следующую ошибку: "RuntimeError: Попытка вернуться по графику во второй раз, но буферы уже были освобождены. При первом вызове в обратном направлении укажите "keep_graph = True".

Проблема устранена, когда я вношу следующие изменения в метод running_stats():

    mu = (1. - self.momentum) * self.mu + self.momentum * batch_mu
    var = (1. - self.momentum) * self.var + self.momentum * batch_var
    return mu, var

Но вот так буферы не обновляются, и в этом весь смысл их использования. Как я могу решить эту проблему и в чем проблема?

...