У меня есть фрейм данных примерно так:
log.comb CDEM_TWI Gruber_Ruggedness dNBR TC_Change_Sexton_Rel \
0 8.714914 10.70240 0.626106 0.701591 -27.12220
1 6.501334 10.65650 1.146360 0.693891 -35.52890
2 8.946111 13.58910 1.146360 0.513136 7.00000
3 8.955151 9.85036 1.126980 0.673891 13.81380
4 7.751379 7.28264 0.000000 0.256136 10.06940
5 8.895197 8.36555 0.000000 0.506000 -27.61340
6 8.676571 12.92650 0.000000 0.600627 -44.48400
7 8.562267 12.76980 0.519255 0.747009 -29.84790
8 9.052766 11.81580 0.519255 0.808336 -29.00900
9 9.133744 9.42046 0.484616 0.604891 -18.53550
10 8.221441 9.53682 0.484616 0.817336 -21.39920
11 8.398913 12.32050 0.519255 0.814745 -18.12080
12 7.587468 11.08880 1.274430 0.590282 92.85710
13 7.983136 8.95073 1.274430 0.316000 -10.34480
14 9.044404 11.18440 0.698818 0.608600 -14.77000
15 8.370293 11.96980 0.687634 0.323000 -9.60452
16 7.938134 12.42380 0.709549 0.374027 36.53140
17 8.183456 12.73490 1.439180 0.679627 -12.94420
18 8.322246 9.61600 0.551689 0.642900 37.50000
19 7.934997 7.77564 0.519255 0.690936 -25.29880
20 9.049387 11.16000 0.519255 0.789064 -35.73880
21 8.071323 6.17036 0.432980 0.574355 -22.43590
22 6.418345 5.98927 0.432980 0.584991 4.34783
23 7.950516 5.49527 0.422882 0.689009 25.22520
24 6.355529 7.35982 0.432980 0.419045 -18.81920
25 8.043683 5.18300 0.763596 0.582555 50.56180
26 6.013468 5.34018 0.493781 0.241155 -3.01205
27 7.961675 5.43264 0.493781 0.421527 -21.72290
28 8.074614 11.94630 0.493781 0.451800 11.61620
29 8.370570 6.34100 0.492384 0.550127 -12.50000
Pct_Pima Sand._15cm
0 75.62120 44.6667
1 69.30690 41.8333
2 59.47490 41.8333
3 66.08800 41.5000
4 34.31250 39.6667
5 35.04750 39.2424
6 62.32120 41.6667
7 57.14320 43.3333
8 57.35020 43.3333
9 72.90980 41.0000
10 57.61790 38.8333
11 57.35020 39.8333
12 69.30690 47.8333
13 69.30690 47.3333
14 76.58910 42.8333
15 75.62120 45.3333
16 76.69440 41.7727
17 59.47090 37.8333
18 61.10130 42.8333
19 72.67650 38.1818
20 57.35020 40.6667
21 23.15380 48.0000
22 17.15050 51.5000
23 0.00000 47.5000
24 6.67001 58.0000
25 15.18050 54.8333
26 5.89344 49.0000
27 5.89344 49.1667
28 13.18900 48.5000
29 13.30450 49.0000
Я хочу запустить линейную модель через 10-кратную перекрестную проверку, повторенную 100 раз.
В Python я делаю это:
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import RepeatedKFold
from sklearn.metrics import r2_score
X = df[['CDEM_TWI', 'Gruber_Ruggedness', 'dNBR', 'TC_Change_Sexton_Rel', 'Pct_Pima', 'Sand._15cm']].copy()
y = df[['log.comb']].copy()
all_r2 = []
rskf = RepeatedKFold(n_splits=10, n_repeats=10, random_state=42)
for train_index, test_index in rskf.split(X, y):
X_train, X_test = X.iloc[train_index], X.iloc[test_index]
y_train, y_test = y.iloc[train_index], y.iloc[test_index]
lm = LinearRegression(fit_intercept = True)
lm.fit(X_train, y_train)
pred = lm.predict(X_test)
r2 = r2_score(y_test, pred)
all_r2.append(r2)
avg = np.mean(all_r2)
и здесь avg
возвращает -0.11
В РИ это делается:
library(caret)
library(klaR)
train_control <- trainControl(method="repeatedcv", number=10, repeats=10)
model <- train(log.comb~., data=df, trControl=train_control, method="lm")
и model
возвращает:
RMSE Rsquared MAE
0.7868838 0.6132806 0.7047198
Мне любопытно, почему эти результаты так несовместимы друг с другом?Я понимаю, что сгибы между двумя разными языками не одинаковы, но так как я повторяю это много раз, я не понимаю, почему числа не так похожи.
Я также пробовал поиск по вложенной сеткев sklearn, например, так:
inner_cv = KFold(n_splits=10, shuffle=True, random_state=10)
outer_cv = KFold(n_splits=10, shuffle=True, random_state=10)
param_grid = {'fit_intercept': [True, False],
'normalize': [True, False]}
# Non_nested parameter search and scoring
clf = GridSearchCV(estimator=LinearRegression(), param_grid = param_grid, cv=inner_cv)
clf.fit(X, y)
non_nested_score = clf.best_score_
# Pass the gridSearch estimator to cross_val_score
clf = GridSearchCV(estimator= LinearRegression(), param_grid = param_grid, cv=inner_cv)
nested_score = cross_val_score(clf, X=X, y=y, cv=outer_cv).mean()
, но и nested_score
, и non_nested_score
оба по-прежнему отрицательны.