Я пытался реализовать пользовательскую функцию нормализации партии, чтобы ее можно было расширить до версии с несколькими графическими процессорами, в частности, до модуля DataParallel в Pytorch. Пользовательский пакетный режим работает нормально при использовании 1 графического процессора, но при расширении до 2 или более, среднее значение и дисперсия работают в функции пересылки, но когда они возвращаются из сети, среднее значение и дисперсия повторно инициализируются до 0 и 1.
В torch.nn.DataParallel упоминается в раздел предупреждения: «В каждой пересылке модуль реплицируется на каждое устройство, поэтому любые обновления работающего модуля в пересылке будут потеряны. Например, если модуль имеет атрибут счетчика, который увеличивается на каждую пересылку, он всегда будет оставаться на начальное значение, поскольку обновление выполняется для реплик, которые уничтожаются после пересылки. " Но я не совсем уверен, как сохранить среднее значение и отклонение от устройства по умолчанию.
Я предоставил код с результатом, полученным во время обучения нескольких GPU. В этом коде используется предоставленный Batchnorm здесь .
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from torch.nn.parameter import Parameter
class ptrblck_BatchNorm2d(nn.BatchNorm2d):
def __init__(self, num_features, eps=1e-5, momentum=0.1,
affine=True, track_running_stats=True):
super(ptrblck_BatchNorm2d, self).__init__(
num_features, eps, momentum, affine, track_running_stats)
def forward(self, input):
self._check_input_dim(input)
exponential_average_factor = 0.0
if self.training and self.track_running_stats:
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
# calculate running estimates
if self.training:
mean = input.mean([0, 2, 3])
# use biased var in train
var = input.var([0, 2, 3], unbiased=False)
n = input.numel() / input.size(1)
with torch.no_grad():
self.running_mean = exponential_average_factor * mean\
+ (1 - exponential_average_factor) * self.running_mean
# update running_var with unbiased var
self.running_var = exponential_average_factor * var * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_var
else:
mean = self.running_mean
var = self.running_var
input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
if self.affine:
input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]
return input
class net(nn.Module):
def __init__(self):
super(net, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.bn1 = ptrblck_BatchNorm2d(64)
print("==> printing bn1 mean when init")
print(self.bn1.running_mean)
print("==> printing bn1 when init")
print(self.bn1.running_mean)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.classifier = nn.Linear(64, 10)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.pool(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
print("======================================================")
print("==> printing bn1 running mean from NET during forward")
print(net.module.bn1.running_mean)
print("==> printing bn1 running mean from SELF. during forward")
print(self.bn1.running_mean)
print("==> printing bn1 running var from NET during forward")
print(net.module.bn1.running_var)
print("==> printing bn1 running mean from SELF. during forward")
print(self.bn1.running_var)
return x
# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Model
print('==> Building model..')
net = net()
net = torch.nn.DataParallel(net).cuda()
print('Number of GPU {}'.format(torch.cuda.device_count()))
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# Training
def train(epoch):
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.cuda(), targets.cuda()
outputs = net(inputs)
loss = criterion(outputs, targets)
print("====================================================")
print("==> printing bn1 running mean FROM net after forward")
print(net.module.bn1.running_mean)
print("==> printing bn1 running var FROM net after forward")
print(net.module.bn1.running_var)
break
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
# train_loss += loss.item()
# _, predicted = outputs.max(1)
# total += targets.size(0)
# correct += predicted.eq(targets).sum().item()
# break
for epoch in range(0, 1):
train(epoch)
Результат:
==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
==> Building model..
==> printing bn1 mean when init
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
==> printing bn1 when init
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Number of GPU 2
Epoch: 0
======================================================
==> printing bn1 running mean from NET during forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([ 0.0053, 0.0010, -0.0077, -0.0290, 0.0241, 0.0258, -0.0048, 0.0151,
-0.0133, 0.0080, 0.0197, -0.0042, -0.0188, 0.0233, 0.0310, -0.0230,
-0.0133, 0.0222, 0.0119, -0.0042, -0.0220, -0.0169, -0.0342, -0.0025,
0.0338, -0.0070, 0.0202, 0.0050, 0.0108, 0.0008, 0.0363, 0.0347,
-0.0106, 0.0082, 0.0128, 0.0074, 0.0111, -0.0030, -0.0089, 0.0070,
-0.0262, -0.0029, 0.0053, -0.0136, -0.0183, 0.0045, -0.0014, -0.0221,
0.0132, 0.0064, 0.0388, -0.0220, -0.0008, 0.0400, -0.0187, 0.0397,
-0.0131, -0.0176, 0.0035, 0.0055, -0.0270, 0.0066, -0.0149, 0.0135],
device='cuda:0')
==> printing bn1 running var from NET during forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([0.9665, 0.9073, 0.9220, 1.0947, 1.0687, 0.9624, 0.9252, 0.9131, 0.9066,
0.9536, 0.9258, 0.9203, 1.0359, 0.9690, 1.1066, 1.0636, 0.9135, 0.9644,
0.9373, 0.9846, 0.9696, 0.9454, 1.0459, 0.9245, 0.9778, 0.9709, 0.9352,
0.9995, 0.9657, 0.9510, 1.0943, 1.0171, 0.9298, 1.0747, 0.9341, 0.9635,
0.9978, 0.9303, 0.9261, 0.9137, 0.9569, 1.0066, 1.0463, 0.9955, 0.9621,
0.9172, 0.9836, 0.9817, 0.9086, 0.9576, 1.0905, 0.9861, 0.9661, 1.1773,
0.9345, 1.0904, 0.9133, 1.0660, 0.9164, 0.9058, 0.9446, 0.9225, 1.0914,
0.9292], device='cuda:0')
======================================================
==> printing bn1 running mean from NET during forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([-0.0020, 0.0002, -0.0103, -0.0426, 0.0386, 0.0311, -0.0059, 0.0151,
-0.0140, 0.0145, 0.0218, -0.0029, -0.0281, 0.0284, 0.0449, -0.0329,
-0.0107, 0.0278, 0.0135, -0.0123, -0.0260, -0.0214, -0.0423, -0.0035,
0.0410, -0.0097, 0.0276, 0.0102, 0.0197, -0.0001, 0.0483, 0.0451,
-0.0078, 0.0190, 0.0135, -0.0004, 0.0196, -0.0028, -0.0140, 0.0070,
-0.0332, -0.0110, 0.0151, -0.0210, -0.0226, 0.0074, -0.0088, -0.0314,
0.0125, -0.0003, 0.0505, -0.0312, 0.0086, 0.0544, -0.0245, 0.0528,
-0.0086, -0.0290, 0.0063, 0.0042, -0.0339, 0.0061, -0.0277, 0.0092],
device='cuda:1')
==> printing bn1 running var from NET during forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([0.9665, 0.9072, 0.9211, 1.0999, 1.0714, 0.9610, 0.9209, 0.9125, 0.9063,
0.9553, 0.9260, 0.9189, 1.0386, 0.9706, 1.1139, 1.0610, 0.9121, 0.9660,
0.9366, 0.9886, 0.9683, 0.9454, 1.0511, 0.9227, 0.9792, 0.9704, 0.9330,
0.9989, 0.9657, 0.9476, 1.1008, 1.0191, 0.9294, 1.0814, 0.9320, 0.9642,
1.0006, 0.9287, 0.9254, 0.9128, 0.9559, 1.0100, 1.0521, 0.9972, 0.9621,
0.9168, 0.9849, 0.9803, 0.9083, 0.9556, 1.0946, 0.9865, 0.9651, 1.1880,
0.9330, 1.0959, 0.9116, 1.0706, 0.9149, 0.9057, 0.9450, 0.9215, 1.0972,
0.9261], device='cuda:1')
====================================================
==> printing bn1 running mean FROM net after forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
device='cuda:0')
==> printing bn1 running var FROM net after forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')
Как я могу убедиться, что используются оценки работы устройства по умолчанию? В настоящее время я не работаю над синхронизированным Batchnorm.