Как добавить регуляризацию L1 в модель PyTorch NN? - PullRequest
0 голосов
/ 30 сентября 2019

При поиске способов реализации регуляризации L1 в моделях PyTorch я наткнулся на этот вопрос , которому уже 2 года, поэтому мне было интересно, есть ли что-нибудь новое по этой теме?

Я также нашел этот недавний подход к отсутствующей функции l1. Однако я не понимаю, как использовать его для базового NN, как показано ниже.

class FFNNModel(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim, dropout_rate):
        super(FFNNModel, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout_rate
        self.drop_layer = nn.Dropout(p=self.dropout_rate)
        self.fully = nn.ModuleList()
        current_dim = input_dim
        for h_dim in hidden_dim:
            self.fully.append(nn.Linear(current_dim, h_dim))
            current_dim = h_dim
        self.fully.append(nn.Linear(current_dim, output_dim))

    def forward(self, x):
        for layer in self.fully[:-1]:
            x = self.drop_layer(F.relu(layer(x)))
        x = F.softmax(self.fully[-1](x), dim=0)
        return x

Я надеялся, что просто поставить это перед тренировкой сработает:

model = FFNNModel(30,5,[100,200,300,100],0.2)
regularizer = _Regularizer(model)
regularizer = L1Regularizer(regularizer, lambda_reg=0.1)

с

out = model(inputs)
loss = criterion(out, target) + regularizer.__add_l1()

Кто-нибудь понимает, как применять эти «готовые к использованию»? 'classes?

РЕДАКТИРОВАТЬ / ПРОСТОЕ РЕШЕНИЕ: для всех, кто сталкивался с этим:

Всегда были некоторые проблемы с классами Regularizer_ по ссылке выше, поэтому я решил проблемуиспользуя обычные функции, добавив также ортогональный регуляризатор:

def l1_regularizer(model, lambda_l1=0.01):
    lossl1 = 0
    for model_param_name, model_param_value in model.named_parameters():
            if model_param_name.endswith('weight'):
                lossl1 += lambda_l1 * model_param_value.abs().sum()
    return lossl1    

def orth_regularizer(model, lambda_orth=0.01):
    lossorth = 0
    for model_param_name, model_param_value in model.named_parameters():
            if model_param_name.endswith('weight'):
                param_flat = model_param_value.view(model_param_value.shape[0], -1)
                sym = torch.mm(param_flat, torch.t(param_flat))
                sym -= torch.eye(param_flat.shape[0])
                lossorth += lambda_orth * sym.sum()
    return lossorth  

и во время тренировки:

loss = criterion(outputs, y_data)\
      +l1_regularizer(model, lambda_l1=lambda_l1)\
      +orth_regularizer(model, lambda_orth=lambda_orth)   

1 Ответ

3 голосов
/ 30 сентября 2019

Я не запускал рассматриваемый код, поэтому, пожалуйста, свяжитесь с нами, если что-то не работает. В целом, я бы сказал, что код, который вы связали, излишне сложен (это может быть потому, что он пытается быть универсальным и допускает все следующие виды регуляризации). Я предполагаю, что

model = FFNNModel(30,5,[100,200,300,100],0.2)
regularizer = L1Regularizer(model, lambda_reg=0.1)

, а затем

out = model(inputs)
loss = criterion(out, target) + regularizer.regularized_all_param(0.)

Вы можете проверить, что regularized_all_param будет просто повторять по параметрамвашей модели, и если их имя оканчивается на weight, она накапливает их сумму абсолютных значений. По какой-то причине буфер должен быть инициализирован вручную, поэтому мы передаем 0..

Действительно, хотя, если вы хотите эффективно упорядочить L1 и не нуждаетесь в каких-либо прибамбасах, тем более ручнымподход, родственный вашей первой ссылке, будет более читабельным. Это будет выглядеть так:

l1_regularization = 0.
for param in model.parameters():
    l1_regularization += param.abs().sum()
loss = criterion(out, target) + l1_regularization

Это действительно то, что лежит в основе обоих подходов. Вы используете метод Module.parameters для итерации по всем параметрам модели и суммируете их нормы L1, которые затем становятся термином в вашей функции потерь. Вот и все. Репо, которое вы связали, предлагает какой-то причудливый механизм для его отвлечения, но, судя по вашему вопросу, не получается:)

...