Spark ML API для преобразования вектора в вероятности для многослойной классификации - PullRequest
0 голосов
/ 09 сентября 2018

Я немного новичок в Spark ML API. Я пытаюсь сделать мульти-ярлычную классификацию для 160 ярлыков, обучая 160 классификаторов (логистика или случайный лес и т. Д.). Когда я тренируюсь на наборе данных [LabeledPoint], мне трудно получить API, в котором я получаю вероятность для каждого класса для одного примера. Я читал на SO, что вы можете использовать API конвейера и получить вероятности, но для моего варианта использования это будет сложно, потому что мне придется повторить 160 RDD для моих функций оценки, получить вероятность для каждого класса, а затем сделать объединение, чтобы ранжировать классы по их вероятностям. Вместо этого я хочу иметь только одну копию оценочных функций, транслировать 160 моделей и затем делать прогнозы внутри функции карты. Я чувствую себя вынужденным реализовать это, но задаюсь вопросом, есть ли еще один удобный API в Spark, чтобы сделать то же самое для различных классификаторов, таких как Logistic / RF, который преобразует вектор, представляющий объекты, в вероятность его принадлежности к классу. Пожалуйста, дайте мне знать, если есть лучший способ приблизиться к классификации с несколькими метками в Spark.

РЕДАКТИРОВАТЬ: я пытался создать функцию для преобразования вектора в метку для случайного леса, но это очень раздражает, потому что теперь я должен клонировать большие куски обхода дерева в Spark, и почти везде я сталкивался с тупиками, потому что некоторые функции или переменная была частной или защищенной. Поправьте меня, если ошиблись, но если этот вариант использования еще не реализован, я думаю, что он, по крайней мере, вполне оправдан, потому что Scikit-learn уже имеет такие API для этого.

Спасибо

1 Ответ

0 голосов
/ 09 сентября 2018

Найдена виновная строка в коде Spark MLLib: https://github.com/apache/spark/blob/5ad644a4cefc20e4f198d614c59b8b0f75a228ba/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala#L224

Метод прогнозирования помечен как защищенный, но для таких вариантов использования он должен быть общедоступным.

Это было исправлено в версии 2.4, как показано здесь: https://github.com/apache/spark/blob/branch-2.4/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala

Так что обновление до версии 2.4 должно сработать ... хотя я не думаю, что 2.4 уже вышел, поэтомуэто вопрос ожидания.

РЕДАКТИРОВАТЬ: для людей, которые заинтересованы, очевидно, это не только полезно для прогнозирования с несколькими метками, было отмечено, что в латентности также есть улучшение в 3-4 раза для регулярной классификации / регрессиидля прогнозов единичных / небольших партий (подробности см. https://issues.apache.org/jira/browse/SPARK-16198).

...