Что у меня есть
У меня есть конвейер, который работает с моими дистрибутивами гиперпараметров
pipe = Pipeline(steps=[
('scale', MinMaxScaler()),
('vt', VarianceThreshold()),
('pca', PCA(random_state=0)),
('select', SelectPercentile()),
('clf', RandomForestClassifier(random_state=0))
])
hyper_params0 = {
'vt__threshold' : stats.distributions.uniform(0, 0.1),
'pca__n_components' : stats.distributions.uniform(0.8, 0.19),
'select__percentile' : stats.distributions.randint(1, 101),
'clf__n_estimators' : stats.distributions.randint(50, 1000),
'clf__criterion' : ['gini', 'entropy'],
'clf__min_samples_split' : stats.distributions.uniform(0, 0.1),
'clf__min_samples_leaf' : stats.distributions.uniform(0, 0.1),
'clf__max_features' : ['sqrt', 'log2', None],
'clf__bootstrap' : [True, False],
}
hyper_params=[
{
**hyper_params0,
**{
'select__score_func' : [mutual_info_classif],
}
},
{
**hyper_params0,
**{
'select__score_func' : [f_classif],
}
}
]
rscv = RandomizedSearchCV(
estimator=pipe,
param_distributions=hyper_params,
n_iter=25,
cv=5,
scoring='f1_macro',
n_jobs=-1,
random_state=0,
verbose=3
)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
rscv.fit(X_train, y_train)
Что я хочу
Я бы хотел выполнить поиск по n_neighbors
параметр в mutual_info_classif
в SelectPercentile
.
Что я пробовал
Я пробовал редактировать hyper_params
вот так:
hyper_params=[
{
**hyper_params0,
**{
'select__score_func' : [mutual_info_classif],
'select__score_func__n_neighbors' : stats.distributions.randint(3, 15)
}
},
{
**hyper_params0,
**{
'select__score_func' : [f_classif],
}
}
]
Но я получаю ошибка AttributeError: 'function' object has no attribute 'set_params'
. Я последовал простому примеру на сайте scikit-learn здесь , но далеко не продвинулся. Также пробовал использовать 'passthrough'
, например:
pipe = Pipeline(steps=[
('scale', MinMaxScaler()),
('vt', VarianceThreshold()),
('pca', PCA(random_state=0)),
('select', 'passthrough'),
('clf', RandomForestClassifier(random_state=0))
])
...
hyper_params=[
{
**hyper_params0,
**{
'select__score_func' : [SelectPercentile(mutual_info_classif)],
'select__score_func__n_neighbors' : stats.distributions.randint(3, 15)
}
},
{
**hyper_params0,
**{
'select__score_func' : [SelectPercentile(f_classif)],
}
}
]
Но получаю ошибку AttributeError: 'str' object has no attribute 'set_params'
.
Вопрос
Есть какие-нибудь советы, как это сделать?