Как применяются class_weights в логистической регрессии sklearn? - PullRequest
0 голосов
/ 20 мая 2018

Меня интересует, как sklearn применяет вес класса, который мы поставляем.В документации не указано, где и как применяются веса классов.Также не помогает чтение исходного кода (кажется, что для оптимизации используется sklearn.svm.liblinear, и я не могу прочитать исходные коды, поскольку это файл .pyd ...)

Но я думаю,он работает с функцией стоимости: когда указаны веса классов, стоимость соответствующего класса будет умножена на вес класса.Например, если у меня есть 2 наблюдения, каждое из которых относится к классу 0 (вес = 0,5) и классу 1 (вес = 1), соответственно, функция стоимости будет:

Стоимость = 0,5 * log (...X_0, y_0 ...) + 1 * log (... X_1, y_1 ...) + штраф

Кто-нибудь знает, правильно ли это?

1 Ответ

0 голосов
/ 20 мая 2018

Отметьте следующие строки в исходном коде :

le = LabelEncoder()
if isinstance(class_weight, dict) or multi_class == 'multinomial':
    class_weight_ = compute_class_weight(class_weight, classes, y)
    sample_weight *= class_weight_[le.fit_transform(y)]

Вот исходный код для функции compute_class_weight() :

...
else:
    # user-defined dictionary
    weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
    if not isinstance(class_weight, dict):
        raise ValueError("class_weight must be dict, 'balanced', or None,"
                         " got: %r" % class_weight)
    for c in class_weight:
        i = np.searchsorted(classes, c)
        if i >= len(classes) or classes[i] != c:
            raise ValueError("Class label {} not present.".format(c))
        else:
            weight[i] = class_weight[c]
...

В приведенном выше фрагменте class_weight применяются к sample_weight, который используется в нескольких внутренних функциях, таких как _logistic_loss_and_grad , _logistic_loss и т. Д.: * 10101

# Logistic loss is the negative of the log of the logistic function.
out = -np.sum(sample_weight * log_logistic(yz)) + .5 * alpha * np.dot(w, w)
# NOTE: --->  ^^^^^^^^^^^^^^^
...