Python: как интерпретировать и улучшить функциюgnast_proba () в RandomForest - PullRequest
0 голосов
/ 29 января 2020

Поэтому я использую sci-kit learns RandomForestClassifier для классификации данных астрономических источников по трем категориям. Чтобы сделать мой вопрос более простым, я использовал только два источника в моем наборе тестов и получил predict_prob() баллов с:

predictions = rf_model.predict(data_test)
probab =  rf_model.predict_proba(data_test)

print(probab)
print('True Classifications:', classif_test.values)
print('Predictions', predictions) 

, давая мне следующее:

[[0.29 0.69 0.02]
 [0.08 0.92 0.  ]]
True Classifications: ['HMXB' 'AGN']
Predictions ['HMXB' 'HMXB']

где класс заказ [AGN, HMXB, SNR]. Проблема в том, что одно из этих предсказаний неверно, а другое верно.

У меня есть несколько вопросов. (а) как мне узнать, какой predict_prob() балл соответствует неверному прогнозу? б) Что именно описывает predict_prob()? Насколько вероятно, что классификация модели верна или что-то еще? (б) Что означает высокий класс вероятности для класса, который приводит к неточному предсказанию? Мой набор данных слишком мал или есть способы улучшить прогнозируемые вероятности?

Итак, для моих данных у меня 46 HMXB, 17 AGN и 3 SNR. У каждого источника есть три атрибута. Я знаю, что это небольшой набор данных, но мне интересно, слишком ли он мал, чтобы RandomForest или другие алгоритмы машинного обучения давали точные результаты.

1 Ответ

3 голосов
/ 29 января 2020

Для вопроса (b) Что именно описываетести в рамках функцииprent_prob ()?
вести_прогноз () даст вероятность меток.
например, если у вас есть три метки ['A', 'B', 'C'], а предикат_prob () дает [0.29,0.69, .02], означает, что результат этих конкретных данных имеет вероятность 0,29 стать 'A', 0,69 вероятность быть 'B', 0,02 вероятность быть 'C'.

Для вопроса (а), как я могу сказать, какой предикат_проб () оценка соответствует неверному прогнозу?
Из опубликованных вами результатов

[[0.29 0.69 0.02]
 [0.08 0.92 0.  ]]
Predictions ['HMXB' 'HMXB']

Это ясно говорит о том, что второй элемент в каждом списке соответствует 'HMXB'. И две другие вероятности (первый элемент и последний элемент) нам нужно посмотреть на данные и сказать.

Да, данные у вас небольшие и довольно несбалансированные. Потому что у вас есть много сэмплов для «HMXB» по сравнению с двумя другими. Попытайтесь получить больше образцов для других этикеток.

...