Реализация группы Лассо на весовых матрицах PyTorch - PullRequest
2 голосов
/ 03 мая 2019

Я пытаюсь реализовать Group Lasso на весовых матрицах нейронной сети в PyTorch.

Я написал код для реализации Group Lasso, но не уверен, что это правильно, подтверждение или исправление моего кода будетбыть очень полезным.

def gl_norm(model, gl_lambda, num_blk):
    gl_reg = torch.tensor(0., dtype=torch.float32).cuda()
    for key in model:
        for param in model[key].parameters():
            dim = param.size()
            if dim.__len__() > 1 and not model[key].skip_regularization:
                div1 = list(torch.chunk(param,int(num_blk),1))
                all_blks = []
                for div2 in div1:
                    temp = list(torch.chunk(div2,int(num_blk),0))
                    for blk in temp:
                        all_blks.append(blk)
                for l2_param in all_blks:
                    gl_reg += torch.norm(l2_param, 2)
    return gl_reg * float(gl_lambda)

Я ожидаю, что функция torch.chunk разбивает матрицу весов на маленькие блоки, которые затем проходят норму L2 для блока и норму L1 между всеми блоками.

...