sklearn.metrics.roc_curve () дает более короткий вывод, чем уникальные значения входных пробников - PullRequest
0 голосов
/ 26 февраля 2020

Я думаю, что количество пороговых значений, выводимых из функции sklearn.metrics.roc_curve(y_true, y_score), зависит от количества уникальных значений входных данных y_score, прогнозируемых вероятностей из модели.

Но когда я запускал модель на небольшом наборе данных с исходными значениями y_scores (без округления, имеющим сотни уникальных значений) roc_curve выдает только 8 порогов

print(len(y_tr1))
print(len(pd.Series(tr_probs).unique()))
print(pd.Series(tr_probs).unique()[:10])
fpr_tr, tpr_tr, thresholds_tr = roc_curve(y_true=y_tr1, y_score=tr_probs)
print(fpr_tr)
print(thresholds_tr)
>>>
761
761
[0.03697368 0.01084693 0.01446446 0.01041848 0.00291137 0.00524581
 0.06648957 0.00339727 0.04496222 0.00595941]
[0.         0.         0.         0.00145985 0.00145985 0.00437956
 0.00437956 1.        ]
[1.92762100e+00 9.27621004e-01 7.77746886e-01 7.49798873e-01
 7.46834446e-01 7.35994455e-01 7.26034969e-01 1.80392157e-03]

Тогда я не мог понять, почему, и попытался округлить на прогнозируемые пробники, каким-то образом длина выходных порогов меняется на количество цифр, которые хранят пробники.

for i in range(8):
    print(len(roc_curve(y_true=y_tr1, y_score=tr_probs.round(i))[0]))
>>>
3
11
81
272
78
16
10
10

Пример, приведенный ниже:

roc_curve(y_true=y_tr1, y_score=tr_probs.round(4))[2]
>>> array([1.9276e+00, 9.2760e-01, 8.9110e-01, 8.8940e-01, 8.5630e-01,
       8.5620e-01, 8.5330e-01, 8.5290e-01, 8.5040e-01, 8.4770e-01,
       8.3740e-01, 8.3640e-01, 7.7770e-01, 7.4980e-01, 7.4680e-01,
       7.3600e-01, 7.2600e-01, 5.8490e-01, 5.8370e-01, 5.7670e-01,
       5.7640e-01, 5.0900e-01, 5.0830e-01, 4.8980e-01, 4.8940e-01,
       4.7480e-01, 4.7460e-01, 4.7350e-01, 4.7190e-01, 4.2640e-01,
       4.2620e-01, 4.2370e-01, 4.2340e-01, 4.1160e-01, 4.1020e-01,
       3.7230e-01, 3.7140e-01, 3.6190e-01, 3.6130e-01, 3.2010e-01,
       3.2000e-01, 3.1800e-01, 3.1640e-01, 3.0100e-01, 3.0010e-01,
       2.9850e-01, 2.9830e-01, 2.9440e-01, 2.9250e-01, 2.8550e-01,
       2.8500e-01, 2.7400e-01, 2.7340e-01, 2.5470e-01, 2.5360e-01,
       2.2510e-01, 2.2340e-01, 2.1800e-01, 2.1540e-01, 1.9300e-01,
       1.9240e-01, 1.5010e-01, 1.5000e-01, 1.4850e-01, 1.4780e-01,
       2.7600e-02, 2.2900e-02, 2.1600e-02, 2.1400e-02, 1.1500e-02,
       1.0800e-02, 8.5000e-03, 8.3000e-03, 5.2000e-03, 4.5000e-03,
       3.7000e-03, 3.4000e-03, 1.8000e-03])

Может ли кто-нибудь помочь мне понять, почему и как с этим бороться? Спасибо!

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...