Я хочу использовать классификатор случайных лесов для несбалансированных данных, где X - это массив np.ar, представляющий объекты, а y - массив np.ar, представляющий метки (метки с 90% 0-значениями и 10% 1-значениями) , Поскольку я не был уверен, как выполнить стратификацию в рамках перекрестной проверки, и если это имеет значение, я также вручную проверял перекрестную проверку с помощью StratifiedKFold. Я ожидаю не такие же, но несколько похожие результаты. Поскольку это не тот случай, я предполагаю, что я ошибочно использую один метод, но я не понимаю, какой. Вот код
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split
from sklearn.metrics import f1_score
rfc = RandomForestClassifier(n_estimators = 200,
criterion = "gini",
max_depth = None,
min_samples_leaf = 1,
max_features = "auto",
random_state = 42,
class_weight = "balanced")
X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size = 0.20, random_state = 42, stratify=y)
Я также попробовал классификатор без аргумента class_weight. Отсюда я перехожу к сравнению обоих методов с f1-баллом
cv = cross_val_score(estimator=rfc,
X=X_train_val,
y=y_train_val,
cv=10,
scoring="f1")
print(cv)
. 10 баллов по f1 из перекрестной проверки составляют около 65%. Теперь StratifiedKFold:
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
for train_index, test_index in skf.split(X_train_val, y_train_val):
X_train, X_val = X_train_val[train_index], X_train_val[test_index]
y_train, y_val = y_train_val[train_index], y_train_val[test_index]
rfc.fit(X_train, y_train)
rfc_predictions = rfc.predict(X_val)
print("F1-Score: ", round(f1_score(y_val, rfc_predictions),3))
10 f1-баллов от StraifiedKFold дают мне значения около 90%. Вот где я запутался, так как не понимаю больших отклонений между обоими методами. Если я просто подгоняю классификатор к данным поезда и применяю его к тестовым данным, я получаю также f1-баллы около 90%, что позволяет мне полагать, что мой способ применения cross_val_score неправильный.