Я хочу реализовать пакетную нормализацию самостоятельно в pytorch, однако у меня есть проблемы с буферами для хранения текущего среднего и дисперсии. Это мой код:
class BatchNorm(nn.Module):
def __init__(self, num_features, momentum=0.1, eps=1e-5):
super(BatchNorm, self).__init__()
self.num_features = num_features
self.momentum = momentum
self.eps = eps
self.alpha = nn.Parameter(torch.ones(num_features, requires_grad=True))
self.beta = nn.Parameter(torch.zeros(num_features, requires_grad=True))
self.register_buffer('mu', torch.zeros(num_features))
self.register_buffer('var', torch.ones(num_features))
def forward(self, x):
mu, var = self.update_running_stats(x)
normalizer = torch.sqrt(var + self.eps)
x = (x - mu) / normalizer # element-wise standardization
x = self.alpha * x + self.beta # scale and shift
return x
def update_running_stats(self, x):
batch_mu = x.sum(0) / x.size(0)
batch_var = ((x - batch_mu)**2).sum(0) / x.size(0)
self.mu = (1. - self.momentum) * self.mu + self.momentum * batch_mu
self.var = (1. - self.momentum) * self.var + self.momentum * batch_var
return self.mu, self.var
Когда я запускаю свой код, как этот, я (с вызовом backward()
при потере только один раз), я получаю следующую ошибку: "RuntimeError: Попытка вернуться по графику во второй раз, но буферы уже были освобождены. При первом вызове в обратном направлении укажите "keep_graph = True".
Проблема устранена, когда я вношу следующие изменения в метод running_stats()
:
mu = (1. - self.momentum) * self.mu + self.momentum * batch_mu
var = (1. - self.momentum) * self.var + self.momentum * batch_var
return mu, var
Но вот так буферы не обновляются, и в этом весь смысл их использования. Как я могу решить эту проблему и в чем проблема?