Классификатор Scikit Learn со свойствами смешанного типа возвращает 0% точности с тестовыми данными - PullRequest
1 голос
/ 08 октября 2019

Я новичок в машинном обучении и Python. Я хочу использовать DecisionTreeClassifier от sklearn. Так как мои функции являются частично числовыми и частично категориальными, мне нужно преобразовать их, потому что DecisionTreeClassifier принимает только числовые функции в качестве входных данных. Для этого я использую ColumnTransformer и конвейеры. Идея заключается в следующем:

  1. Категориальные и числовые признаки преобразуются в отдельные конвейеры
  2. Оба вместе образуют вход для классификатора

ОднакоТочность использования моих тестовых данных всегда равна 0%, в то время как моя точность с данными тренировок составляет ~ 85%. Кроме того, вызов cross_val_score () возвращает

ValueError: Found unknown categories ['Holand-Netherlands'] in column 7 during transform

Это странно, потому что я использовал эти самые данные для обучения full_pipeline. Использование разных классификаторов приводит к одинаковому поведению, что приводит меня к мысли, что есть проблема с преобразованиями. Помощь очень ценится!

Ниже мой код:

names = ["age",
         "workclass",
         "final-weight",
         "education",
         "education-num",
         "martial-status",
         "occupation",
         "relationship",
         "race",
         "sex",
         "capital-gain",
         "capial-loss",
         "hours-per-week",
         "native-country",
         "agrossincome"]

categorical_features = ["workclass", "education", "martial-status", "occupation", "relationship", "race", "sex", "native-country"]
numerical_features = ["age","final-weight", "education-num", "capital-gain", "capial-loss", "hours-per-week"] 
features = np.concatenate([categorical_features, numerical_features])


# create pandas dataframe for adult dataset
adult_train = pd.read_csv(filepath_or_buffer= "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data" ,
            delimiter= ',',
            index_col = False,
            skipinitialspace = True,
            header = None,
            names = names )

adult_test = pd.read_csv( filepath_or_buffer= "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test" ,
            delimiter= ',',
            index_col = False,
            skipinitialspace = True,
            header = None,
            names = names )

adult_test.drop(0, inplace =True)
adult_test.reset_index(inplace = True)
adult_train.replace(to_replace= "?", value = np.NaN, inplace = True)
adult_test.replace(to_replace= "?", value = np.NaN, inplace= True)


# split data into features and targets
x_train = adult_train[features]
y_train = adult_train.agrossincome

x_test = adult_test[features]
y_test = adult_test.agrossincome


# create pipeline for preprocessing + classifier
categorical_pipeline = Pipeline( steps = [ ( 'imputer', SimpleImputer(strategy='constant', fill_value='missing') ),
                                           ( 'encoding', OrdinalEncoder() ) 
                                         ])

numerical_pipeline = Pipeline( steps = [ ( 'imputer', SimpleImputer(strategy='median') ),
                                         ( 'std_scaler', StandardScaler( with_mean = False ) ) 
                                       ])

preprocessing = ColumnTransformer( transformers = [ ( 'categorical_pipeline', categorical_pipeline, categorical_features ), 
                                                   ( 'numerical_pipeline', numerical_pipeline, numerical_features ) ] )

full_pipeline = Pipeline(steps= [ ('preprocessing', preprocessing),
                                  ('model', DecisionTreeClassifier(random_state= 0, max_depth = 5) ) ])

full_pipeline.fit(x_train, y_train)
print(full_pipeline.score(x_test, y_test))
#print(cross_val_score(full_pipeline, x_train, y_train, cv=3).mean())

1 Ответ

3 голосов
/ 08 октября 2019

Ошибка происходит от y_test, который выглядит как

enter image description here

, а

enter image description here

Удаление '.'в конце следует исправить это

enter image description here

...