Наследование от SciKit FunctionTransformer - PullRequest
0 голосов
/ 25 мая 2020

Я хотел бы использовать FunctionTransformer и в то же время предоставить простой API и скрыть дополнительные детали. В частности, я хотел бы предоставить класс Custom_Trans, как показано ниже. Таким образом, вместо trans1, который работает нормально, пользователь должен иметь возможность использовать trans2, который в данный момент не работает:

from sklearn import preprocessing 
from sklearn.pipeline import Pipeline
from sklearn import model_selection
from sklearn.linear_model import LinearRegression
from sklearn.datasets import make_regression
import numpy as np

X, y = make_regression(n_samples=100, n_features=1, noise=0.1)

def func(X, a, b):
    return X[:,a:b]

class Custom_Trans(preprocessing.FunctionTransformer):
    def __init__(self, ind0, ind1):
        super().__init__(
            func=func,
            kw_args={
                "a": ind0,
                "b": ind1
            }
        )

trans1 = preprocessing.FunctionTransformer( 
    func=func,
    kw_args={
        "a": 0,
        "b": 50
    }
)

trans2 = Custom_Trans(0,50)

pipe1 = Pipeline(
    steps=[
           ('custom', trans1),
           ('linear', LinearRegression())
         ]
)

pipe2 = Pipeline(
    steps=[
           ('custom', trans2),
           ('linear', LinearRegression())
          ]
)

print(model_selection.cross_val_score(
    pipe1, X, y, cv=3,)
)

print(model_selection.cross_val_score(
    pipe2, X, y, cv=3,)
)

Вот что я получаю:

[0.99999331 0.99999671 0.99999772]
...sklearn/base.py:209: FutureWarning: From version 0.24, get_params will raise an
AttributeError if a parameter cannot be retrieved as an instance attribute. 
Previously it would return None.
warnings.warn('From version 0.24, get_params will raise an '
...
[0.99999331 0.99999671 0.99999772]

Я знаю, что это связано с клонированием оценщика, но не знаю, как это исправить. Например, этот пост говорит, что

не должно быть logi c, даже проверки ввода, в оценщиках init . Лог c должен быть помещен там, где используются параметры, что обычно соответствует

, но в этом случае мне нужно передать параметры суперклассу. Нет возможности поместить лог c в fit(). Что я могу сделать?

1 Ответ

0 голосов
/ 25 мая 2020

Вы можете получить get_params, унаследовав от BaseEstimator.

class FunctionTransformer(BaseEstimator, TransformerMixin)

Как передать параметры в класс customize modeltransformer

наследовать от function_transformer

нестандартные трансформаторы

У вас в базе:

def get_params(self, deep=True):
        """
        Get parameters for this estimator.
        Parameters
        ----------
        deep : bool, default=True
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.
        Returns

Измените код:

trans1 = dict(
    functiontransformer__kw_args=[
        {'ind0': None},
        {'ind0': [1]}
    ]
)

class Custom_Trans(preprocessing.FunctionTransformer): 
    def __init__(self, ind0, ind1, deep=True): 
        super().__init__( func=func, kw_args={ "a": ind0, "b": ind1 } ) 
        self.ind0 = ind0
        self.ind1 = ind1
        self.deep = True 
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...