Я обучил модель RF (с 3 деревьями и глубиной 4) на тренировочном наборе из 15 примеров. Ниже изображение того, как выглядят три дерева. У меня есть два класса (скажем, 0 и 1).
Пороговые значения указаны в левой ветви, а число в кругах (например, 7, 3 - количество примеров, которые были <= threshold и> threshold для функции 2, т.е. f2).
Теперь, когда я пытаюсь применить модель к тестовому набору из 10 примеров, я не уверен, как вычисляется необработанный прогноз.
+-----+----+----+----------+-------------------------------------------------------------------
|prediction|features |rawPrediction|probability |
+-----+----+----+----------+-----------------------------------------------------------------------------------------------------------+-------------+---------------------------------------+
|1.0 |[0.07707524933080619,0.03383458646616541,0.017208413001912046,9.0,2.5768015000258258,0.0,-1.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0014143059186559938,0.0,0.6666666666666667,7.076533785087878E-4,0.0014163090128755495,0.9354143466934853,0.9333333333333333,0.875,0.938888892531395,7.0] |[1.0,2.0] |[0.3333333333333333,0.6666666666666666]|
Я прошел по следующим ссылкам, чтобы понять, но я не могу обдумать это.
https://forums.databricks.com/questions/14355/how-does-randomforestclassifier-compute-the-rawpre.html
https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
Я точно знаю, что это не так просто, как вы думаете. Например, насколько я понимаю, это не так - например, если два дерева предсказывают 0, а одно дерево предсказывает как 1, тогда необработанный прогноз будет [2, 1]. Это не тот случай, потому что, когда я обучаю модель на 500 примерах, я вижу предварительный прогноз для этого же примера [0.9552544653780279,2.0447455346219723].
Может кто-нибудь объяснить мне, как математически это вычисляется? Любая помощь будет оценена здесь, так как это своего рода c, и я хочу понять, как это работает. Еще раз спасибо большое заранее, и, пожалуйста, отправьте, если есть какая-либо другая информация, необходимая для решения этой проблемы.
Редактировать: Добавление данных из модели:
+------+-----------------------------------------------------------------------------------------------------+
|treeID|nodeData |
+------+-----------------------------------------------------------------------------------------------------+
|0 |[0, 0.0, 0.5, [9.0, 9.0], 0.19230769230769235, 1, 4, [2, [0.12519961673586713], -1]] |
|0 |[1, 0.0, 0.42603550295857984, [9.0, 4.0], 0.42603550295857984, 2, 3, [20, [0.39610389610389607], -1]]|
|0 |[2, 0.0, 0.0, [9.0, 0.0], -1.0, -1, -1, [-1, [], -1]] |
|0 |[3, 1.0, 0.0, [0.0, 4.0], -1.0, -1, -1, [-1, [], -1]] |
|0 |[4, 1.0, 0.0, [0.0, 5.0], -1.0, -1, -1, [-1, [], -1]] |
|1 |[0, 1.0, 0.4444444444444444, [5.0, 10.0], 0.4444444444444444, 1, 2, [4, [0.9789660448762616], -1]] |
|1 |[1, 1.0, 0.0, [0.0, 10.0], -1.0, -1, -1, [-1, [], -1]] |
|1 |[2, 0.0, 0.0, [5.0, 0.0], -1.0, -1, -1, [-1, [], -1]] |
|2 |[0, 0.0, 0.48, [3.0, 2.0], 0.48, 1, 2, [20, [0.3246753246753247], -1]] |
|2 |[1, 0.0, 0.0, [3.0, 0.0], -1.0, -1, -1, [-1, [], -1]] |
|2 |[2, 1.0, 0.0, [0.0, 2.0], -1.0, -1, -1, [-1, [], -1]] |
+------+-----------------------------------------------------------------------------------------------------+