Я делаю 500 выборок из набора из 10000 строк данных просто для простоты.Пожалуйста, скопируйте и вставьте X и y в вашу IDE.
X =
array([ -8.93, -0.17, 1.47, -6.13, -4.06, -2.22, -2.11, -0.25,
0.25, 0.49, 1.7 , -0.77, 1.07, 5.61, -11.95, -3.8 ,
-3.42, -2.55, -2.44, -1.99, -1.7 , -0.98, -0.91, -0.91,
-0.25, 1.7 , 2.88, -6.9 , -4.07, -1.35, -0.33, 0.63,
0.98, -3.31, -2.61, -2.61, -2.17, -1.38, -0.77, -0.25,
-0.08, -1.2 , -3.1 , -1.07, -0.7 , -0.41, -0.33, 0.41,
0.77, 0.77, 1.14, 2.17, -7.92, -3.8 , -2.11, -2.06,
-1.2 , -1.14, 0. , 0.56, 1.47, -1.99, -0.17, 2.44,
-5.87, -3.74, -3.37, -2.88, -0.49, -0.25, -0.08, 0.33,
0.33, 0.84, 1.64, 2.06, 2.88, -4.58, -1.82, -1.2 ,
0.25, 0.25, 0.63, 2.61, -5.36, -1.47, -0.63, 0. ,
0.63, 1.99, 1.99, -10.44, -2.55, 0.33, -8.93, -5.87,
-5.1 , -2.78, -0.25, 1.47, 1.93, 2.17, -5.36, -5.1 ,
-3.48, -2.44, -2.06, -2.06, -1.82, -1.58, -1.58, -0.63,
-0.33, 0. , 0.17, -3.31, -0.25, -5.1 , -3.8 , -2.55,
-1.99, -1.7 , -0.98, -0.91, -0.63, -0.25, 0.77, 0.91,
0.91, -9.43, -8.42, -2.72, -2.55, -1.26, 0.7 , 0.77,
1.07, 1.47, 1.7 , -1.82, -1.47, 0.17, 1.26, -5.36,
-1.52, -1.47, -0.17, -3.48, -3.31, -2.06, -1.47, 0.17,
0.25, 1.7 , 2.5 , -9.94, -6.08, -5.87, -3.37, -2.44,
-2.17, -1.87, -0.98, -0.7 , -0.49, 0.41, 1.47, 2.28,
-14.95, -12.44, -6.39, -4.33, -3.8 , -2.72, -2.17, -1.2 ,
0.41, 0.77, 0.84, 2.51, -1.99, -1.7 , -1.47, -1.2 ,
0.49, 0.63, 0.84, 0.98, 1.14, 2.5 , -2.06, -1.26,
-0.33, 0.17, 4.58, -7.41, -5.87, 1.2 , 1.38, 1.58,
1.82, 1.99, -6.39, -2.78, -2.67, -1.87, -1.58, -1.47,
0.84, -10.44, -7.41, -3.05, -2.17, -1.07, -1.07, -0.91,
0.25, 1.82, 2.88, -6.9 , -1.47, 0.33, -8.42, -3.8 ,
-1.99, -1.47, -1.47, -0.56, 0.17, 0.17, 0.25, 0.56,
4.58, -3.48, -2.61, -2.44, -0.7 , 0.63, 1.47, 1.82,
-13.96, -9.43, -2.67, -1.38, -0.08, 0. , 1.82, 3.05,
-4.58, -3.31, -0.98, -0.91, -0.7 , 0.77, -0.7 , -0.33,
0.56, 1.58, 1.7 , 2.61, -4.84, -4.84, -4.32, -2.88,
-1.38, -0.98, -0.17, 0.17, 0.49, 2.44, 4.32, -3.48,
-3.05, 0.56, -8.42, -3.48, -2.61, -2.61, -2.06, -1.47,
-0.98, 0. , 0.08, 1.38, 1.93, -9.94, -2.72, -1.87,
-1.2 , -1.07, 1.58, 4.58, -6.64, -2.78, -0.77, -0.7 ,
-0.63, 0.49, 1.07, -8.93, -4.84, -1.7 , 1.76, 3.31,
-11.95, -3.16, -3.05, -1.82, -0.49, -0.41, 0.56, 1.58,
-13.96, -3.05, -2.78, -2.55, -1.7 , -1.38, -0.91, -0.33,
1.2 , 1.32, 1.47, -2.06, -1.82, -7.92, -6.33, -4.32,
-3.8 , -1.93, -1.52, -0.98, -0.49, -0.33, 0.7 , 1.52,
1.76, -8.93, -7.41, -2.88, -2.61, -2.33, -1.99, -1.82,
-1.64, -0.84, 1.07, 2.06, -3.96, -2.44, -1.58, 0. ,
-3.31, -2.61, -1.58, -0.25, 0.33, 0.56, 0.84, 1.07,
-1.58, -0.25, 1.35, -1.99, -1.7 , -1.47, -1.47, -0.84,
-0.7 , -0.56, -0.33, 0.56, 0.63, 1.32, 2.28, 2.28,
-2.72, -0.25, 0.41, -6.9 , -4.42, -4.32, -1.76, -1.2 ,
-1.14, -1.07, 0.56, 1.32, 1.52, -14.97, -7.41, -5.1 ,
-2.61, -1.93, -0.98, 0.17, 0.25, 0.41, -4.42, -2.61,
-0.91, -0.84, 2.39, -2.61, -1.32, 0.41, -6.9 , -5.61,
-4.06, -3.31, -1.47, -0.91, -0.7 , -0.63, 0.33, 1.38,
2.61, -2.29, 3.06, 4.44, -10.94, -4.32, -3.42, -2.17,
-1.7 , -1.47, -1.32, -1.07, -0.7 , 0. , 0.77, 1.07,
-3.31, -2.88, -2.61, -1.47, -1.38, -0.63, -0.49, 1.07,
1.52, -3.8 , -1.58, -0.91, -0.7 , 0.77, 3.42, -8.42,
-2.88, -1.76, -1.76, -0.63, -0.25, 0.49, 0.63, -6.9 ,
-4.06, -1.82, -1.76, -1.76, -1.38, -0.91, -0.7 , 0.17,
1.38, 1.47, 1.47, -11.95, -0.98, -0.56, -14.97, -9.43,
-8.93, -2.72, -2.61, -1.64, -1.32, -0.56, -0.49, 0.91,
1.2 , 1.47, -3.8 , -3.06, -2.51, -1.04, -0.33, -0.33,
-3.31, -3.16, -3.05, -2.61, -1.47, -1.07, 2.17, 3.1 ,
-2.61, -0.25, -3.85, -2.44])
y =
array([1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0,
1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1,
0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1,
1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1,
1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1,
1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0,
0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1,
0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1,
1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1,
1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0,
0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1,
1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0,
1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,
0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0,
1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1,
1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0,
0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0,
0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0,
1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1])
Инициализация и обучение:
from sklearn.linear_model import LogisticRegression
model = LogisticRegression()
model.fit(X, y)
Перекрестная проверка:
from sklearn.model_selection import cross_val_score
cross_val_score(model, X, y, cv=10, scoring='r2').mean()
-0,3339677563815496 (отрицательный R2?)
Чтобы убедиться, что он близок к истинному R2 модели.Я сделал это:
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=None, shuffle=False)
r2_score(y_test, model.predict_proba(X_test)[:,1], multioutput='variance_weighted')
0,32642659661798396
Этот R2 имеет больше смысла для соответствия модели, и похоже, что два R2 являютсяпросто +/- знак переключателя, но это не так.В моей модели, использующей гораздо большую выборку, кросс-значение R2 составляет -0,24, а тест R2 - 0,18.И когда я добавляю функцию, которая, как представляется, полезна для модели, тест R2 повышается, а кросс-значение R2 уменьшается
Кроме того, если вы переключите LogisticRegression на LinearRegression, кросс-значение R2 теперь будет положительным и будет близко к R2тестовое задание.Что вызывает эту проблему?