Перекрестная проверка в sklearn: нужно ли вызывать fit () и cross_val_score ()? - PullRequest
0 голосов
/ 14 мая 2018

Я хотел бы использовать перекрестную проверку в k-кратном порядке при изучении модели. Пока я делаю это так:

# splitting dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(dataset_1, df1['label'], test_size=0.25, random_state=4222)

# learning a model
model = MultinomialNB()
model.fit(X_train, y_train)
scores = cross_val_score(model, X_train, y_train, cv=5)

На этом шаге я не совсем уверен, должен ли я использовать model.fit () или нет, потому что в официальной документации sklearn они не подходят, а просто вызывают cross_val_score следующим образом (они не даже разбить данные на обучающие и тестовые наборы):

from sklearn.model_selection import cross_val_score
clf = svm.SVC(kernel='linear', C=1)
scores = cross_val_score(clf, iris.data, iris.target, cv=5)

Я бы хотел настроить гиперпараметры модели во время изучения модели. Какой правильный трубопровод?

Ответы [ 2 ]

0 голосов
/ 14 мая 2018

Если вы хотите сделать выбор гиперпараметра, посмотрите на RandomizedSearchCV или GridSearchCV. Если вы хотите впоследствии использовать лучшую модель, затем вызовите любую из них с помощью refit=True, а затем используйте best_estimator_.

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV

log_params = {'penalty': ['l1', 'l2'], 'C': [1E-7, 1E-6, 1E-6, 1E-4, 1E-3]}
clf = LogisticRegression()
search = RandomizedSearchCV(clf, scoring='average_precision', cv=10,
                            n_iter=10, param_distributions=log_params,
                            refit=True, n_jobs=-1)
search.fit(X_train, y_train)
clf = search.best_estimator_

http://scikit -learn.org / стабильный / модули / полученные / sklearn.model_selection.RandomizedSearchCV.html

0 голосов
/ 14 мая 2018

Ваш второй пример подходит для перекрестной проверки. Смотрите пример здесь: http://scikit -learn.org / stable / modules / cross_validation.html # computing-cross-validated-metrics

Подгонка будет выполнена внутри функции cross_val_score, вам не нужно заранее об этом беспокоиться.

[Отредактировано] Если, помимо перекрестной проверки, вы хотите обучить модель, вы можете позвонить model.fit() впоследствии.

...