Склеарнская модель случайного леса слишком велика - PullRequest
1 голос
/ 12 марта 2019

Вопрос от новичка в sklearn, пожалуйста, сообщите. У меня RandomForestClassifier модель обучена со следующими параметрами:

n_estimators = 32,
criterion = 'gini',
max_depth = 380,

Эти параметры выбраны не случайно, по какой-то причине они показали лучшую производительность ... хотя и кажутся мне странными.

Размер модели составляет около 5,5 ГБ при сохранении с joblib.dump и compress=3

Используемые данные:

tfidf=TfidfVectorizer()
X_train=tfidf.fit_transform(X_train)

и

le=LabelEncoder()
le.fit(y_train)
y_train=le.fit_transform(y_train)

с размером выборки 4,7Mio записей, разделенных 0,3 (70% поезд, 30% тест)

Теперь у меня есть вопрос, может быть, кто-то может помочь с:

Имеет ли для вас смысл параметры, используемые для модели, и размер модели по отношению к размеру выборки? Вероятно, выбор параметров не является оптимальным для модели, которая увеличивает размер (я понимаю, что основной параметр, увеличивающий размер, равен max_depth, но результат был лучшим ...)

Возможно, есть какие-либо предложения по параметрам или подготовке данных в целом, так как в моем опыте с этим образцом я заметил следующее: 1. Увеличение n_estimators практически не влияет на результат; 2. Увеличение max_depth, с другой стороны, приносит значительные улучшения. Как пример: - max_depth = 10 - accuracy_score из 0,3 - max_depth = 380 - accuracy_score 0,95

Любые предложения, совет очень приветствуется! :) 1037 *

UPD. Точность результатов

Оценка поезда: 0,988 classifier.score

Оценка OOB: 0,953 classifier.oob_score_

Результаты тестов: 0,935 sklearn.metrics -> accuracy_score

1 Ответ

0 голосов
/ 17 марта 2019

Попробуйте использовать min_samples_leaf вместо max_depth для ограничения глубины дерева. Это допускает разные глубины для разных путей дерева и для разных оценщиков. Надеемся, что это позволит найти модель, которая имеет хорошие характеристики с меньшей средней глубиной. Мне нравится устанавливать min_samples_leaf как число с плавающей точкой, что означает долю от количества сэмплов. Попробуйте gridsearch между (0,0001, 0,1)

...