Spark MultilayerPerceptronClassifier Класс Вероятности - PullRequest
0 голосов
/ 06 февраля 2019

Я опытный программист Python, пытающийся перевести некоторый код Python в Spark для задачи классификации.Я впервые работаю в Spark / Scala.

В Python нейронные сети Keras / tenorflow и sci-kit Learn отлично справляются с многоклассовой классификацией, и я могу легко вернуть 3 самых вероятных класса вместе с вероятностями, которые являются ключевымик этому проекту.

В целом мне удалось переместить код в Spark (Scala), и я смог сгенерировать правильные прогнозы, но я не смог найти способ вернуть вероятности для самых предсказуемых классов изMultilayerPerceptronClassifier в MLlib.

Ближайшее решение, которое я нашел, было в этом сообщении: Как получить вероятности классификации из MultilayerPerceptronClassifier? Однако я не могу заставить работать решение в посте, потому что оно отсутствуетключевой фрагмент кода или я слишком новичок в Scala (возможно, последний), чтобы внести необходимые корректировки.

Кто-нибудь решил эту проблему?

Это текущие версии в моей среде.Версия Spark: 2.1.1 Версия Scala: 2.11.8

Спасибо за помощь,

RKB

1 Ответ

0 голосов
/ 06 февраля 2019

Если вы внимательно посмотрите на результаты MultilayerPerceptronClassificationModel.transform (model и test, как определено в примере конвейера в официальной документации )

val result = model.transform(test)

result.printSchema
root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)

вы увидите, что они содержат probability столбец.

Он хранится как o.a.s.ml.linalg.Vector столбец:

result.select($"probability").show(3, false)
+---------------------------------------------------+
|probability                                        |
+---------------------------------------------------+
|[2.630203838780848E-29,1.7323171642231641E-19,1.0] |
|[1.0,1.448487547623119E-121,4.530084532282489E-44] |
|[1.0,5.157808976162274E-122,2.5702890543589884E-44]|
+---------------------------------------------------+
only showing top 3 rows

и может быть доступен с использованием стандарта методы .

Эта функция доступна с версии Spark 2.3 ( SPARK-12664 Вероятность выставления, rawPrediction в MultilayerPerceptronClassificationModel ).

...