I want to implement meta-learning with PyTorch DistributedDataParallel, but there are two problems:

1. After setting loss.backward(retain_graph=True, create_graph=True), I get RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time. If only retain_graph=True is set, everything works fine.

2. With only retain_graph=True set, the code runs, but the second-order gradient changes with the number of ranks while the losses do not. The key problem seems to come from get_updated_network(). However, I cannot figure out why this happens or how to fix it.

For example:
DataParallel
grad1 : tensor(3.5100, device='cuda:0')
net2 : tensor([[ -3.9810, -13.5981, -4.1402, -36.6334]], device='cuda:0',
grad_fn=<SubBackward0>)
grad2 : tensor(7.0200, device='cuda:0')
loss1 : -8.781850814819336, loss2 : -1550.81201171875
DistributedDataParallel
rank 1
grad1 : tensor(3.5100, device='cuda:0')
net2 : tensor([[ -3.9810, -13.5981, -4.1402, -36.6334]], device='cuda:0',
grad_fn=<SubBackward0>)
grad2 : tensor(7.0200, device='cuda:0')
loss1 : -8.781850814819336, loss2 : -1550.81201171875
rank 2
grad1 : tensor(3.5100, device='cuda:0')
net2 : tensor([[ -3.9810, -13.5981, -4.1402, -36.6334]], device='cuda:0',
grad_fn=<SubBackward0>)
grad2 : tensor(6.5200, device='cuda:0')
loss1 : -8.781851768493652, loss2 : -1550.81201171875
rank 4
grad1 : tensor(3.5100, device='cuda:0')
net2 : tensor([[ -3.9810, -13.5981, -4.1402, -36.6334]], device='cuda:0',
grad_fn=<SubBackward0>)
grad2 : tensor(5.5200, device='cuda:0')
loss1 : -8.781850814819336, loss2 : -1550.81201171875
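For reference, this is roughly the second-order pattern I am trying to reproduce with DDP, shown here as a minimal single-GPU sketch (a MAML-style inner step; it is only illustrative and separate from the demo below):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy linear model and a batch of inputs (illustrative only).
net = nn.Linear(4, 1, bias=False)
x = torch.randn(8, 4)

inner_loss = net(x).mean()
# create_graph=True keeps the graph of the gradient itself,
# so the outer backward can differentiate through the inner update.
grads = torch.autograd.grad(inner_loss, net.parameters(), create_graph=True)

# "Fast" weight: a differentiable function of the original parameter.
fast_weight = net.weight - 1.0 * grads[0]

outer_loss = F.linear(x, fast_weight).mean()
outer_loss.backward()  # gradients flow back through the inner update into net.weight.grad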
The demo code is as follows:
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
torch.manual_seed(1234)
def put_theta(model, theta):
    # Recursively walk the module tree and overwrite each _parameters entry
    # with the matching tensor from theta (keeping theta's autograd history).
    def k_param_fn(tmp_model, name=None):
        if len(tmp_model._modules) != 0:
            for (k, v) in tmp_model._modules.items():
                if name is None:
                    k_param_fn(v, name=str(k))
                else:
                    k_param_fn(v, name=str(name + '.' + k))
        else:
            for (k, v) in tmp_model._parameters.items():
                if not isinstance(v, torch.Tensor):
                    continue
                tmp_model._parameters[k] = theta[str(name + '.' + k)]

    k_param_fn(model)
    return model
def get_updated_network(old, new, lr, load=False):
    # Build "fast" weights: each parameter that has a gradient becomes
    # (param - lr * param.grad); everything else is copied from the state dict.
    updated_theta = {}
    state_dicts = old.state_dict()
    param_dicts = dict(old.named_parameters())
    # print(param_dicts['module.backbone.conv1.0.weight'].grad.sum(), '\n')
    for i, (k, v) in enumerate(state_dicts.items()):
        if k in param_dicts.keys() and param_dicts[k].grad is not None:
            updated_theta[k] = param_dicts[k] - lr * param_dicts[k].grad
        else:
            updated_theta[k] = state_dicts[k]
    if load:
        new.load_state_dict(updated_theta)
    else:
        new = put_theta(new, updated_theta)
    return new
class Datax(Dataset):
    def __getitem__(self, item):
        data = [0.01, 10, 0.4, 33]
        data = torch.tensor(data).view(4).float() + item
        # data = torch.randn((4,))
        return data

    def __len__(self):
        return 100


class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.conv = nn.Sequential(
            nn.Linear(4, 1, bias=False),
        )

    def forward(self, x):
        x = self.conv(x)
        return x.mean()
def train(rank, world_size):
    if rank == -1:
        net = nn.DataParallel(Net2()).cuda()
        net2 = nn.DataParallel(Net2()).cuda()
        opt1 = torch.optim.SGD(net.parameters(), lr=1e-3)
        dataset = Datax()
        loader = DataLoader(dataset, batch_size=8, shuffle=False, pin_memory=True, sampler=None)
    else:
        net = DistributedDataParallel(Net2().cuda(), device_ids=[rank], find_unused_parameters=True)
        net2 = DistributedDataParallel(Net2().cuda(), device_ids=[rank], find_unused_parameters=True)
        opt1 = torch.optim.SGD(net.parameters(), lr=1e-3)
        dataset = Datax()
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, rank=rank, shuffle=False)
        loader = DataLoader(dataset, batch_size=8 // world_size, shuffle=False, pin_memory=True, sampler=sampler)

    for i, data in enumerate(loader):
        data = data.cuda()
        l1 = net(data).mean()
        opt1.zero_grad()
        # when create_graph=True is set, the error occurs in the second loop
        l1.backward(retain_graph=True)
        if rank <= 0:
            # the first gradients and losses are the same for all ranks
            print('grad1 : ', net.module.conv[0].weight.grad[0, 0])
        # net2 = net  # using this line instead works fine, so the problem comes from the line below
        # However, all ranks output the same net2 and the same loss, but the grad is different.
        net2 = get_updated_network(net, net2, 1)
        if rank <= 0:
            for k, v in net2.named_parameters():
                print('net2 : ', v)
        l2 = net2(data).mean()
        l2.backward()
        if rank != -1:
            dist.all_reduce(l1), dist.all_reduce(l2)
            l1 = l1 / world_size
            l2 = l2 / world_size
        if rank <= 0:
            print('grad2 : ', net.module.conv[0].weight.grad[0, 0])
            print('loss1 : {}, loss2 : {}'.format(l1.item(), l2.item()))
        opt1.step()
        if i == 0:
            break
def dist_train(proc, ngpus_per_node, args):
    backend = 'nccl'
    url = 'tcp://127.0.0.1:23458'
    world_size = args
    dist.init_process_group(backend=backend, init_method=url, world_size=world_size, rank=proc)
    torch.cuda.set_device(proc)
    train(proc, world_size)


if __name__ == '__main__':
    train(-1, 4)
    for i in [1, 2, 4]:
        print('\n')
        ranks = i
        torch.manual_seed(1234)
        mp.spawn(dist_train, nprocs=ranks, args=(ranks, ranks))