Я думаю, что количество пороговых значений, выводимых из функции sklearn.metrics.roc_curve(y_true, y_score)
, зависит от количества уникальных значений входных данных y_score, прогнозируемых вероятностей из модели.
Но когда я запускал модель на небольшом наборе данных с исходными значениями y_scores (без округления, имеющим сотни уникальных значений) roc_curve выдает только 8 порогов
print(len(y_tr1))
print(len(pd.Series(tr_probs).unique()))
print(pd.Series(tr_probs).unique()[:10])
fpr_tr, tpr_tr, thresholds_tr = roc_curve(y_true=y_tr1, y_score=tr_probs)
print(fpr_tr)
print(thresholds_tr)
>>>
761
761
[0.03697368 0.01084693 0.01446446 0.01041848 0.00291137 0.00524581
0.06648957 0.00339727 0.04496222 0.00595941]
[0. 0. 0. 0.00145985 0.00145985 0.00437956
0.00437956 1. ]
[1.92762100e+00 9.27621004e-01 7.77746886e-01 7.49798873e-01
7.46834446e-01 7.35994455e-01 7.26034969e-01 1.80392157e-03]
Тогда я не мог понять, почему, и попытался округлить на прогнозируемые пробники, каким-то образом длина выходных порогов меняется на количество цифр, которые хранят пробники.
for i in range(8):
print(len(roc_curve(y_true=y_tr1, y_score=tr_probs.round(i))[0]))
>>>
3
11
81
272
78
16
10
10
Пример, приведенный ниже:
roc_curve(y_true=y_tr1, y_score=tr_probs.round(4))[2]
>>> array([1.9276e+00, 9.2760e-01, 8.9110e-01, 8.8940e-01, 8.5630e-01,
8.5620e-01, 8.5330e-01, 8.5290e-01, 8.5040e-01, 8.4770e-01,
8.3740e-01, 8.3640e-01, 7.7770e-01, 7.4980e-01, 7.4680e-01,
7.3600e-01, 7.2600e-01, 5.8490e-01, 5.8370e-01, 5.7670e-01,
5.7640e-01, 5.0900e-01, 5.0830e-01, 4.8980e-01, 4.8940e-01,
4.7480e-01, 4.7460e-01, 4.7350e-01, 4.7190e-01, 4.2640e-01,
4.2620e-01, 4.2370e-01, 4.2340e-01, 4.1160e-01, 4.1020e-01,
3.7230e-01, 3.7140e-01, 3.6190e-01, 3.6130e-01, 3.2010e-01,
3.2000e-01, 3.1800e-01, 3.1640e-01, 3.0100e-01, 3.0010e-01,
2.9850e-01, 2.9830e-01, 2.9440e-01, 2.9250e-01, 2.8550e-01,
2.8500e-01, 2.7400e-01, 2.7340e-01, 2.5470e-01, 2.5360e-01,
2.2510e-01, 2.2340e-01, 2.1800e-01, 2.1540e-01, 1.9300e-01,
1.9240e-01, 1.5010e-01, 1.5000e-01, 1.4850e-01, 1.4780e-01,
2.7600e-02, 2.2900e-02, 2.1600e-02, 2.1400e-02, 1.1500e-02,
1.0800e-02, 8.5000e-03, 8.3000e-03, 5.2000e-03, 4.5000e-03,
3.7000e-03, 3.4000e-03, 1.8000e-03])
Может ли кто-нибудь помочь мне понять, почему и как с этим бороться? Спасибо!