Почему OneVsRestClassifier()
возвращает гораздо меньшую оценку для того же набора данных, чем просто использование параметра multi_class="ovr"
?
Используя простой способ подгонки и получения оценки с помощью logisit c регрессия:
#Load Data, assign variables
training_data = pd.read_csv("iris.data")
training_data.columns = [
"sepal_length",
"sepal_width",
"petal_length",
"petal_width",
"class",
]
feature_cols = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
label_cols = ["class"]
X = training_data.loc[:, feature_cols]
y = training_data.loc[:, label_cols].values.ravel()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
# Instantiate and fit the model:
logreg =LogisticRegression(solver="liblinear", multi_class="ovr", random_state=24)
clf = logreg.fit(X_train, y_train)
# See if the model is reasonable.
print("Score: ", clf.score(X_test, y_test))
Я получаю оценку 0.92
, а при использовании OneVsAllRegression
Я получаю оценку 0.62
training_data = pd.read_csv("iris.data")
training_data.columns = [
"sepal_length",
"sepal_width",
"petal_length",
"petal_width",
"class",
]
feature_cols = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
label_cols = ["class"]
X = training_data.loc[:, feature_cols]
y = training_data.loc[:, label_cols].values.ravel()
#transform lables to 0-1-2
le = preprocessing.LabelEncoder()
le.fit(training_data.loc[:, label_cols].values.ravel())
y=le.transform(training_data.loc[:, label_cols].values.ravel())
# Binarize the output
y = label_binarize(y, classes=[0, 1, 2])
n_classes = 3
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
# Instantiate and fit the model:
logreg = OneVsRestClassifier(LogisticRegression(solver="liblinear", random_state=24))
clf = logreg.fit(X_train, y_train)
# See if the model is reasonable.
print("Score: ", clf.score(X_test, y_test))
Есть ли причина, по которой один метод работает лучше, чем другой?
Вот как выглядит ввод данных (это набор данных Iris):
training_data
sepal_length sepal_width petal_length petal_width class
4.9 3.0 1.4 0.2 Iris-setosa
4.7 3.2 1.3 0.2 Iris-setosa
(...)