Ошибка создания собственного преобразователя в sklearn - принимает 2 позиционных аргумента, но 3 были заданы - PullRequest
0 голосов
/ 09 апреля 2020

Я пытаюсь создать конвейер, используя пользовательские трансфомеры для обобщенного набора данных c. Вот мой первый трансформер. Учитывая имя столбца, он разбивает этот столбец даты и времени на следующие столбцы.

class DatePartTransformer:

    def __init__(self,fldname):
        self.fldname = fldname

    def fit(self):
        return self 

    def transform(self):
        return self

    def fit_transform(self,df, drop=True, time=False, errors='raise'):
        fld = df[self.fldname]
        fld_dtype = fld.dtype
        if isinstance(fld_dtype, pd.core.dtypes.dtypes.DatetimeTZDtype):
            fld_dtype = np.datetime64

        if not np.issubdtype(fld_dtype, np.datetime64):
            df[self.fldname] = fld = pd.to_datetime(fld, infer_datetime_format=True, errors=errors)
        targ_pre = re.sub('[Dd]ate$', '', self.fldname)
        attr = ['Year', 'Month', 'Week', 'Day', 'Dayofweek', 'Dayofyear',
            'Is_month_end', 'Is_month_start', 'Is_quarter_end', 'Is_quarter_start', 'Is_year_end', 'Is_year_start']
        if time: attr = attr + ['Hour', 'Minute', 'Second']
        for n in attr: df[targ_pre + n] = getattr(fld.dt, n.lower())
        df[targ_pre + 'Elapsed'] = fld.astype(np.int64) // 10 ** 9
        if drop: df.drop(self.fldname, axis=1, inplace=True)

        return df

и вот мой второй

from pandas.api.types import is_string_dtype

class TrainCats:

    def __init__(self):
        pass

    def fit(self):
        return self

    def transform(self):
        return self

    def fit_transform(self,df):

        for n,c in df.items():
            if is_string_dtype(c): 
                df[n] = c.astype('category').cat.as_ordered()
        return df

Я планирую написать больше.

Здесь это трубопровод.

pipeline = Pipeline([
 ('imputer',DatePartTransformer('date')),
 ('cats',TrainCats())
])

df = pipeline.fit_transform(df_raw)

Когда я запускаю конвейер, я получаю эту ошибку

TypeError                                 Traceback (most recent call last)
<ipython-input-36-36154d1b45b5> in <module>
      4 ])
      5 
----> 6 df = pipeline.fit_transform(df_raw)

c:\users\vishak~1\desktop\env\ml\lib\site-packages\sklearn\pipeline.py in fit_transform(self, X, y, **fit_params)
    391                 return Xt
    392             if hasattr(last_step, 'fit_transform'):
--> 393                 return last_step.fit_transform(Xt, y, **fit_params)
    394             else:
    395                 return last_step.fit(Xt, y, **fit_params).transform(Xt)

TypeError: fit_transform() takes 2 positional arguments but 3 were given

В книге Аурелиана Герона говорится, что так работают конвейеры. Я не могу найти свою ошибку.

1 Ответ

1 голос
/ 10 апреля 2020

Если вы посмотрите на исходный код Pipeline, вы увидите, что для каждого преобразователя требуется принять 2 позиционных аргумента, то есть X и y (кроме self) при использовании fit_transform метод. Это именно эта строка:

                 return last_step.fit_transform(Xt, y, **fit_params)

Таким образом, объявление метода fit_transform вашего преобразователя должно иметь 2 позиционных аргумента. Чтобы исправить это, все, что вам нужно сделать, это предоставить второй фиктивный аргумент для вашего TrainCats fit_transform метода, например:

    def fit_transform(self,df, y=None):

        for n,c in df.items():
            if is_string_dtype(c): 
                df[n] = c.astype('category').cat.as_ordered()
        return df

Это уменьшит вашу ошибку, но есть еще одна уязвимость. Хотя ваш fit_transform в DatePartTransformer принимает более 1 аргумента, из-за предположения о трубопроводе ваш аргумент drop будет переопределен с None или фактическим y от другого преобразователя. Если вы планируете работать только с входными данными, а не с метками, вам также необходимо добавить этот фиктивный аргумент в DatePartTransformer:

     def fit_transform(self,df, y=None, drop=True, time=False, errors='raise'):
        ...

...