Пользовательский регрессор: GridSearchCV говорит, что «get_params» не реализуется при наследовании от BaseEsitmator - PullRequest
0 голосов
/ 15 апреля 2020

Здравствуйте,

Спасибо за то, что нашли время взглянуть на это.

Я работаю над реализацией этой версии scikit-learn API сообщение в блоге , данные доступны здесь . Мой пользовательский класс воспроизводит результаты авторов, но не работает с GridSearchCV.

По сути, он реализует частичную регрессию наименьших квадратов для некоторых спектральных данных, при этом оптимальное количество компонентов определяется как число компонентов, дающих самое низкое MSE. Моя попытка показана ниже, я могу повторить результаты MSE авторов для оптимальной калибровки, и параметры по умолчанию __init__ ниже установлены на эти параметры. Обратите внимание, что я наследую от BaseEstiamtor и RegressorMixin.

#download the .csv from the github repo from the blog post
#Creating df, shuffling, then creating `X` and `y`

df = pd.read_csv("nirpyresearch/data/peach_spectra+brixvalues.csv")
df = df.sample(replace=False, frac=1).copy()
y = df['Brix'].values
X = df[[i for i in list(df.columns) if 'wl' in i]].values
class SavgolPLS(BaseEstimator, RegressorMixin):
    """My Regressor"""
    def __init__(self,  savgol_window = 17, savgol_polyorder = 2, savgol_deriv = 2, pls_components = 7 ):
        self.savgol_window = savgol_window
        self.savgol_polyorder = savgol_polyorder
        self.savgol_deriv = savgol_deriv
        self.pls_components = pls_components

    def fit(self, X, y):

        # Check that X and y have correct shape
        X, y = check_X_y(X, y)


        self.X_ = X
        self.y_ = y
        self.X_savgol_ = savgol_filter(X, self.savgol_window, self.savgol_polyorder, self.savgol_deriv)
        self.pls_ = PLSRegression(n_components=self.pls_components).fit(self.X_savgol_, self.y_)
        # Return the classifier
        return self

    def predict(self, X, apply_savgol = True):

        # Check is fit had been called
        #check_is_fitted(self)

        # Input validation
        X = check_array(X)
        if apply_savgol:
            X = savgol_filter(X, self.savgol_window, self.savgol_polyorder, self.savgol_deriv)
        pred_y = self.pls_.predict(X)
        return pred_y

    def score(self, y_pred):
        mse = mean_squared_error( y_true = self.y_, y_pred=y_pred,)
        return mse


Теперь я могу инициализировать модель и использовать .get_params(), чтобы получить диктат, содержащий 4 параметра в __init__.

s_pls = SavgolPLS(pls_components=7)
s_pls.get_params()

Таким образом, get_params(), кажется, существует. , , что имеет смысл, учитывая, что он был унаследован от BaseEstimator. Я также могу использовать метод fit() для репликации результатов авторов.

s_pls = s_pls.fit(X = X, y = y)
y_pred = s_pls.predict(X)

#This should be ~0.6566
s_pls.score(y_pred)

Почему тогда применение GridSearchCV в приведенном ниже коде приводит к отображаемой ошибке?

parameters  ={'savgol_window':[3,30], 'savgol_polyorder':[2,4], 'savgol_deriv':[1,3], 'pls_components':[2,15]}
clf = GridSearchCV(SavgolPLS, parameters, cv = 10)
clf.fit(X, y)

Выход

TypeError                                 Traceback (most recent call last)
<ipython-input-22-e20c1eabb4fa> in <module>
----> 1 clf.fit(X, y.ravel())

C:\tools\Anaconda3\envs\dev_py37_tf\lib\site-packages\sklearn\model_selection\_search.py in fit(self, X, y, groups, **fit_params)
    631         n_splits = cv.get_n_splits(X, y, groups)
    632 
--> 633         base_estimator = clone(self.estimator)
    634 
    635         parallel = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,

C:\tools\Anaconda3\envs\dev_py37_tf\lib\site-packages\sklearn\base.py in clone(estimator, safe)
     58                             "it does not seem to be a scikit-learn estimator "
     59                             "as it does not implement a 'get_params' methods."
---> 60                             % (repr(estimator), type(estimator)))
     61     klass = estimator.__class__
     62     new_object_params = estimator.get_params(deep=False)

TypeError: Cannot clone object '<class '__main__.SavgolPLS'>' (type <class 'type'>): it does not seem to be a scikit-learn estimator as it does not implement a 'get_params' methods.

Спасибо за помощь!

1 Ответ

0 голосов
/ 15 апреля 2020

Вы передаете класс в GridSearchCV, вам следует передать экземпляр: clf = GridSearchCV(SavgolPLS(), parameters, cv = 10)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...