Краткое описание проблемы
Я использую функцию validation_curve () в Scikit-Learn для построения кривых сложности модели при использовании GridSearchCV (и RandomizedSearchCV) на MLPClassifier , но функция validation_curve () возвращает ошибку:
TypeError: объект 'rv_frozen' не повторяется
Поиск варьируется по альфе и hidden_layer_sizes.Эта ошибка возникает только тогда, когда n_jobs больше 1 .Проблема заключается в запуске validation_curve параллельно с функцией при использовании MLPClassifier.
Есть ли способ решить эту проблему, не уменьшая n_jobs до 1?
Попытки разрешить
Единственный способ решить эту проблему - установить n_jobs = 1. Однако этот код перекрестной проверки работает для других моделей, таких как GradientBoostingClassifier.
Я пробовал как GridSearchCV, так и RandomizedSearchCV.и уменьшил пространство параметров без разрешения.Я также попробовал это на двух уникальных наборах данных и получил ту же ошибку.Я также перешел от использования случайного KFold с 5 разделениями к простой записи cv = 5 в validation_curve ().
Код
def create_validation_curve(df, pipeline, scoring, s_cv, X_train, y_train, param_name, param_range, n_jobs):
df_features = df.drop('target', axis=1)
df_target = df['target']
cv = KFold(n_splits=5, shuffle=True)
cv = cv.split(X_train, y_train)
train_scores, test_scores = validation_curve(
estimator=s_cv.best_estimator_, X=X_train, y=y_train,
param_name=param_name, param_range=param_range, cv=5, scoring=scoring, n_jobs=n_jobs)
Ожидаемые результаты
Я не сделалполучить эту ошибку для других моделей, таких как GradientBoostingClassifier или DecisionTreeClassifier.
Полная ошибка
Полная ошибка приведена ниже:
TypeError Traceback (most recent call last)
<ipython-input-11-363a4bc65384> in <module>
111 param_range = sp_randint(50, 200)
112 train_scores, test_scores = create_validation_curve(
--> 113 df, pl_classifier, scoring, rs_cv, X_train, y_train, param_name, param_range, n_jobs)
114
115 plot_validation_curve(param_range, train_scores, test_scores, param_name, alpha=0.1)
<ipython-input-11-363a4bc65384> in create_validation_curve(df, pipeline, scoring, s_cv, X_train, y_train, param_name, param_range, n_jobs)
48 train_scores, test_scores = validation_curve(
49 estimator=s_cv.best_estimator_, X=X_train, y=y_train,
---> 50 param_name=param_name, param_range=param_range, cv=5, scoring=scoring, n_jobs=n_jobs)
51
52 return train_scores, test_scores
~/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_validation.py in validation_curve(estimator, X, y, param_name, param_range, groups, cv, scoring, n_jobs, pre_dispatch, verbose, error_score)
1445 error_score=error_score)
1446 # NOTE do not change order of iteration to allow one time cv splitters
-> 1447 for train, test in cv.split(X, y, groups) for v in param_range)
1448 out = np.asarray(out)
1449 n_params = len(param_range)
~/anaconda3/lib/python3.7/site-packages/sklearn/externals/joblib/parallel.py in __call__(self, iterable)
915 # remaining jobs.
916 self._iterating = False
--> 917 if self.dispatch_one_batch(iterator):
918 self._iterating = self._original_iterator is not None
919
~/anaconda3/lib/python3.7/site-packages/sklearn/externals/joblib/parallel.py in dispatch_one_batch(self, iterator)
752 tasks = BatchedCalls(itertools.islice(iterator, batch_size),
753 self._backend.get_nested_backend(),
--> 754 self._pickle_cache)
755 if len(tasks) == 0:
756 # No more tasks available in the iterator: tell caller to stop.
~/anaconda3/lib/python3.7/site-packages/sklearn/externals/joblib/parallel.py in __init__(self, iterator_slice, backend_and_jobs, pickle_cache)
208
209 def __init__(self, iterator_slice, backend_and_jobs, pickle_cache=None):
--> 210 self.items = list(iterator_slice)
211 self._size = len(self.items)
212 if isinstance(backend_and_jobs, tuple):
~/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_validation.py in <genexpr>(.0)
1445 error_score=error_score)
1446 # NOTE do not change order of iteration to allow one time cv splitters
-> 1447 for train, test in cv.split(X, y, groups) for v in param_range)
1448 out = np.asarray(out)
1449 n_params = len(param_range)
TypeError: 'rv_frozen' object is not iterable