ColumnTransformer fit_transform не работает с конвейером - PullRequest
0 голосов
/ 06 августа 2020

Пишу конвейер с нестандартным трансформатором. При вызове fit_transform категориального конвейера я получаю желаемый результат, но при вызове fit_transform из ColumnTransformer все, что я инициализировал в init настраиваемого трансформатора, теряется. Примечание: не включая код числового преобразователя для удобочитаемости

class categoryTransformer(BaseEstimator, TransformerMixin):
def __init__(self, use_dates=['year', 'month', 'day']):
    self._use_dates = use_dates
    print('==========>',self._use_dates)
def fit(self, X, y=None):
    return self

def get_year(self, obj):
    return str(obj)[:4]

def get_month(self, obj):
    return str(obj)[4:6]

def get_day(self, obj):
    return str(obj)[6:8]

def create_boolean(self, obj):
    if obj == '0':
        return 'No'
    else:
        return 'Yes'
    
def transform(self, X, y=None):
    print(self._use_dates)

     for spec in self._use_dates:
         print(spec)
         exec("X.loc[:,'{}'] = X['date'].apply(self.get_{})".format(spec, spec))
    
    X = X.drop('date', axis=1)
    X.loc[:,'yr_renovated'] = X['yr_renovated'].apply(self.create_boolean)
    X.loc[:, 'view'] = X['view'].apply(self.create_boolean)
    return X.values

cat_pipe = Pipeline([
('cat_transform', categoryTransformer()),
('one_hot', OneHotEncoder(sparse=False))])

num_pipe = Pipeline([
('num_transform', numericalTransformer()),
('imputer', SimpleImputer(strategy = 'median')),
('std_scaler', StandardScaler())])

full_pipe = ColumnTransformer([
('num', num_pipe, numerical_features),
('cat', cat_pipe, categorical_features)])

cat_pipe.fit_transform(data[categorical_features])#working fine
df2 = full_pipe.fit_transform(X_train)# __init__ initialisation lost

"output"
==========> ['year', 'month', 'day']
['year', 'month', 'day']
year
month
day
==========> None
None

После этой длинной трассировки, которую я не могу отлаживать. Обходной путь - если я могу создать use_dates = ['year', 'month', 'day'] в самой функции преобразования, но я хочу понять, почему это происходит.

1 Ответ

0 голосов
/ 07 августа 2020

Параметры __init__ должны иметь те же имена, что и устанавливаемые атрибуты (поэтому use_dates и _use_dates являются проблемой).

Это требуется для правильной работы клонирования, и ColumnTransformer клонирует все свои трансформаторы перед установкой.

https://scikit-learn.org/stable/developers/develop.html#instantiation

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...