Мета-обучение с pytorch DistributedDataParallel, результат изменяется при изменении ранга? - PullRequest
0 голосов
/ 04 мая 2020

Я хочу реализовать мета-обучение с помощью pytorch DistributedDataParallel. Однако есть две проблемы:

  1. После установки loss.backward(retain_graph=True, create_graph=True) произошла ошибка, сказанная RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time. Если установлено только retain_graph=True, все работает нормально.

  2. Когда я установил retain_graph=True и запустил код, я обнаружил, что градиент второго порядка изменяется в зависимости от количества рангов, в то время как потери не меняются. Я обнаружил, что ключевая проблема может исходить от get_update_network(). Однако я могу понять, как это произошло и как это исправить.

например,

DataParallel

grad1 :  tensor(3.5100, device='cuda:0')
net2 :  tensor([[ -3.9810, -13.5981,  -4.1402, -36.6334]], device='cuda:0',
       grad_fn=<SubBackward0>)
grad2 :  tensor(7.0200, device='cuda:0')
loss1 : -8.781850814819336, loss2 : -1550.81201171875

DistributedDataParallel

rank 1 

grad1 :  tensor(3.5100, device='cuda:0')
net2 :  tensor([[ -3.9810, -13.5981,  -4.1402, -36.6334]], device='cuda:0',
       grad_fn=<SubBackward0>)
grad2 :  tensor(7.0200, device='cuda:0')
loss1 : -8.781850814819336, loss2 : -1550.81201171875

rank 2 

grad1 :  tensor(3.5100, device='cuda:0')
net2 :  tensor([[ -3.9810, -13.5981,  -4.1402, -36.6334]], device='cuda:0',
       grad_fn=<SubBackward0>)
grad2 :  tensor(6.5200, device='cuda:0')
loss1 : -8.781851768493652, loss2 : -1550.81201171875

rank 4 

grad1 :  tensor(3.5100, device='cuda:0')
net2 :  tensor([[ -3.9810, -13.5981,  -4.1402, -36.6334]], device='cuda:0',
       grad_fn=<SubBackward0>)
grad2 :  tensor(5.5200, device='cuda:0')
loss1 : -8.781850814819336, loss2 : -1550.81201171875

Демонстрационный код соблюдается:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

torch.manual_seed(1234)


def put_theta(model, theta):
    def k_param_fn(tmp_model, name=None):
        if len(tmp_model._modules) != 0:
            for (k, v) in tmp_model._modules.items():
                if name is None:
                    k_param_fn(v, name=str(k))
                else:
                    k_param_fn(v, name=str(name + '.' + k))
        else:
            for (k, v) in tmp_model._parameters.items():
                if not isinstance(v, torch.Tensor):
                    continue
                tmp_model._parameters[k] = theta[str(name + '.' + k)]

    k_param_fn(model)
    return model


def get_updated_network(old, new, lr, load=False):
    updated_theta = {}
    state_dicts = old.state_dict()
    param_dicts = dict(old.named_parameters())
    # print(param_dicts['module.backbone.conv1.0.weight'].grad.sum(), '\n')
    for i, (k, v) in enumerate(state_dicts.items()):
        if k in param_dicts.keys() and param_dicts[k].grad is not None:
            updated_theta[k] = param_dicts[k] - lr * param_dicts[k].grad
        else:
            updated_theta[k] = state_dicts[k]
    if load:
        new.load_state_dict(updated_theta)
    else:
        new = put_theta(new, updated_theta)
    return new


class Datax(Dataset):
    def __getitem__(self, item):
        data = [0.01, 10, 0.4, 33]
        data = torch.tensor(data).view(4).float() + item
        # data = torch.randn((4,))
        return data

    def __len__(self):
        return 100


class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.conv = nn.Sequential(
            nn.Linear(4, 1, bias=False),
        )

    def forward(self, x):
        x = self.conv(x)
        return x.mean()


def train(rank, world_size):
    if rank == -1:
        net = nn.DataParallel(Net2()).cuda()
        net2 = nn.DataParallel(Net2()).cuda()
        opt1 = torch.optim.SGD(net.parameters(), lr=1e-3)
        dataset = Datax()
        loader = DataLoader(dataset, batch_size=8, shuffle=False, pin_memory=True, sampler=None)
    else:
        net = DistributedDataParallel(Net2().cuda(), device_ids=[rank], find_unused_parameters=True)
        net2 = DistributedDataParallel(Net2().cuda(), device_ids=[rank], find_unused_parameters=True)
        opt1 = torch.optim.SGD(net.parameters(), lr=1e-3)
        dataset = Datax()
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, rank=rank, shuffle=False)
        loader = DataLoader(dataset, batch_size=8//world_size, shuffle=False, pin_memory=True, sampler=sampler)

    for i, data in enumerate(loader):
        data = data.cuda()
        l1 = net(data).mean()
        opt1.zero_grad()
        # when set the create_graph=True, error occurs in the second loop
        l1.backward(retain_graph=True)
        if rank <= 0:
            # first gradients and losses are all the same for all ranks
            print('grad1 : ', net.module.conv[0].weight.grad[0, 0])

        # net2 = net  # this line works fine, so the error comes from this line
        # However, all ranks outputs same architecture and same loss, but the grad is different.
        net2 = get_updated_network(net, net2, 1)
        if rank <= 0:
            for k, v in net2.named_parameters():
                print('net2 : ', v)

        l2 = net2(data).mean()
        l2.backward()
        if rank != -1:
            dist.all_reduce(l1), dist.all_reduce(l2)
            l1 = l1 / world_size
            l2 = l2 / world_size
        if rank <= 0:
            print('grad2 : ', net.module.conv[0].weight.grad[0, 0])
            print('loss1 : {}, loss2 : {}'.format(l1.item(), l2.item()))

        opt1.step()

        if i == 0:
            break


def dist_train(proc, ngpus_per_node, args):
    backend = 'nccl'
    url = 'tcp://127.0.0.1:23458'
    world_size = args
    dist.init_process_group(backend=backend, init_method=url, world_size=world_size, rank=proc)
    torch.cuda.set_device(proc)
    train(proc, world_size)


if __name__ == '__main__':
    train(-1, 4)
    for i in [1, 2, 4]:
        print('\n')
        ranks = i
        torch.manual_seed(1234)
        mp.spawn(dist_train, nprocs=ranks, args=(ranks, ranks))

1 Ответ

0 голосов
/ 04 мая 2020

Я нашел решение для второй проблемы. в последнем предупреждении о Pytorch DistributedDataParallel говорится:

Никогда не пытайтесь изменить параметры вашей модели после обертывания вашей модели с DistributedDataParallel. Другими словами, при обёртывании вашей модели с DistributedDataParallel конструктор DistributedDataParallel будет регистрировать дополнительные функции уменьшения градиента для всех параметров самой модели во время построения. Если вы измените параметры модели после построения DistributedDataParallel, это не будет поддерживаться, и может произойти непредвиденное поведение, поскольку функции уменьшения градиента некоторых параметров могут не вызываться.

Поскольку я изменил параметры одной сети градиенты не могут быть усреднены автоматически. Таким образом, вторая проблема может быть решена путем усреднения градиента для распределенных моделей самостоятельно.

def average_gradients(model):
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size

loss.backward()
average_gradients(net)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...