Как извлечь важность функций из алгоритма Logitboost в настройке мультиклассовой классификации? - PullRequest
1 голос
/ 27 мая 2020

В настоящее время я использую алгоритм Logitboost с несколькими классами ( docs ), который отлично работает. Однако при попытке просмотреть важность различных функций я получаю следующее сообщение об ошибке:

NotImplementedError: Feature importances is currently only implemented for binary classification tasks.

При просмотре кода Github я не совсем понимаю, почему это еще не реализовано. Кто-нибудь знает способ обойти это, так что я могу определить важность функции, или я ничего не могу сделать, если не дождусь более новой версии Logitboost (что не кажется таким вероятным, поскольку последнее обновление было несколько месяцев назад go ).

Я уже пытался изменить файл Logitboost.py, но, поскольку у меня ограниченные знания о программировании, это довольно утомительный процесс.

Заранее спасибо!

1 Ответ

1 голос
/ 27 мая 2020
• 1000 пока не реализует этот метод для мультиклассовых задач. Однако, глядя на структуру подобранного классификатора, кажется довольно простым придумать некоторую особую метрику важности c.

Давайте посмотрим на примере, используя набор данных радужной оболочки:

import logitboost
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train , X_test, y_train, y_test = train_test_split(X,y)
lg = logitboost.LogitBoost()
lg.fit(X_train, y_train)

Если вы посмотрите на lg.estimators_, вы увидите, что структура представляет собой вложенный список согласованных деревьев решений. Мы могли бы сделать следующее, чтобы получить общую важность:

l_feat_imp = [sum(cls.feature_importances_ for cls in cls_list) 
              for cls_list in lg.estimators_]
imp = np.array(l_feat_imp).sum(0)
# array([ 9., 19., 51., 71.])

т.е. это просто взятие суммы вкладов каждой функции для всех внутренних списков оценщиков, а затем снова суммированных вкладов. Итак, в этом случае у нас будет:

pd.Series(imp, index=load_iris().feature_names).sort_values(ascending=False).plot.bar()

enter image description here

...