Функция для перекрестной проверки и передискретизации (SMOTE) - PullRequest
0 голосов
/ 15 мая 2019

Я написал следующий код.X - это фрейм данных с формой (1000,5), а y - это фрейм данных с формой (1000,1).y - это целевые данные для прогнозирования, и они несбалансированы.Я хочу применить перекрестную проверку и SMOTE.

def Learning(n, est, X, y):
    s_k_fold = StratifiedKFold(n_splits = n)
    acc_scores = []
    rec_scores = []
    f1_scores = []

    for train_index, test_index in s_k_fold.split(X, y): 
        X_train = X[train_index]
        y_train = y[train_index]    

        sm = SMOTE(random_state=42)
        X_resampled, y_resampled = sm.fit_resample(X_train, y_train)

        X_test = X[test_index]
        y_test = y[test_index]

        est.fit(X_resampled, y_resampled)
        y_pred = est.predict(X_test)
        acc_scores.append(accuracy_score(y_test, y_pred))
        rec_scores.append(recall_score(y_test, y_pred))
        f1_scores.append(f1_score(y_test, y_pred)) 

    print('Accuracy:',np.mean(acc_scores))
    print('Recall:',np.mean(rec_scores))
    print('F1:',np.mean(f1_scores)) 

Learning(3, SGDClassifier(), X_train_s_pca, y_train)

Когда я запускаю код, я получаю следующую ошибку:

Ни один из [Int64Index ([4231, 4235, 4246, 4250, 4255, 4295, 4317, 4344, 4381, \ n 4387, \ n ... \ n 13122, 13123, 13124, 13125, 13126, 13127, 13128, 13129, 13130, \ n
13131],\ n dtype = 'int64', length = 8754)] находятся в [столбцах] "

Помощь в его запуске приветствуется.

1 Ответ

0 голосов
/ 15 мая 2019

Если вы внимательно наблюдаете трассировку стека ошибок (что важно, но вы не включаете), вы должны увидеть, что ошибка происходит из этих строк (и будет происходить из других аналогичных строк):

X_train = X[train_index]

Этот способ выбора строк применим только для массива Numpy. Поскольку вы используете Pandas DataFrame, вы должны использовать loc :

X_train = X.loc[train_index]

В качестве альтернативы вы можете вместо этого преобразовать DataFrame в массив Numpy (чтобы минимизировать изменение кода), используя значения :

Learning(3, SGDClassifier(), X_train_s_pca.values, y_train.values)
...