Существует подход «ленивый человек»: вы можете просто подключить nn.BatchNorm2d
в качестве самого первого слоя вашей сети. С соответствующими momentum
и track_running_stats=True
этот слой будет оценивать среднее значение и дисперсию ваших данных.
В качестве альтернативы, вы можете вычислить среднее значение и дисперсию, используя
mu = torch.zeros((3,), dtype=torch.float)
sig = torch.zeros((3,), dtype=torch.float)
n = 0
with torch.no_grad():
for num, data in enumerate(trainset):
imgs, _ = data
mu += torch.sum(imgs, dim=(0, 2, 3))
sig += torch.sum(imgs**2, dim=(0, 2, 3))
n += imgs.numel() // imgs.shape[0]
n = float(n)
mu = mu / n # mean
sig = sig / n - (mu ** 2)