Pyspark экстракт ROC кривой? - PullRequest
       20

Pyspark экстракт ROC кривой?

0 голосов
/ 17 октября 2018

Есть ли способ получить точки на кривой ROC от Spark ML в pyspark?В документации я вижу пример для Scala, но не для python: https://spark.apache.org/docs/2.1.0/mllib-evaluation-metrics.html

Это правильно?Я, конечно, могу думать о способах его реализации, но я должен представить, что это быстрее, если есть встроенная функция.Я работаю с 3 миллионами партитур и несколькими десятками моделей, поэтому скорость имеет значение.

Спасибо!

Ответы [ 2 ]

0 голосов
/ 04 августа 2019

Для более общего решения, которое работает для моделей помимо Логистической регрессии (например, Деревья решений или Случайный лес, в которых отсутствует сводка моделей), вы можете получить кривую ROC, используя BinaryClassificationMetrics от Spark MLlib.

Обратите внимание, что версия PySpark не реализует все методы, которые версия Scala реализует, поэтому вам нужно использовать функцию .call(name) из JavaModelWrapper .Также кажется, что py4j не поддерживает синтаксический анализ scala.Tuple2 классов, поэтому они должны обрабатываться вручную.

Пример:

from pyspark.mllib.evaluation import BinaryClassificationMetrics

# Scala version implements .roc() and .pr()
# Python: https://spark.apache.org/docs/latest/api/python/_modules/pyspark/mllib/common.html
# Scala: https://spark.apache.org/docs/latest/api/java/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.html
class CurveMetrics(BinaryClassificationMetrics):
    def __init__(self, *args):
        super(CurveMetrics, self).__init__(*args)

    def _to_list(self, rdd):
        points = []
        # Note this collect could be inefficient for large datasets 
        # considering there may be one probability per datapoint (at most)
        # The Scala version takes a numBins parameter, 
        # but it doesn't seem possible to pass this from Python to Java
        for row in rdd.collect():
            # Results are returned as type scala.Tuple2, 
            # which doesn't appear to have a py4j mapping
            points += [(float(row._1()), float(row._2()))]
        return points

    def get_curve(self, method):
        rdd = getattr(self._java_model, method)().toJavaRDD()
        return self._to_list(rdd)

Использование:

import matplotlib.pyplot as plt

# Create a Pipeline estimator and fit on train DF, predict on test DF
model = estimator.fit(train)
predictions = model.transform(test)

# Returns as a list (false positive rate, true positive rate)
preds = predictions.select('label','probability').rdd.map(lambda row: (float(row['probability'][1]), float(row['label'])))
roc = CurveMetrics(preds).get_curve('roc')

plt.figure()
x_val = [x[0] for x in points]
y_val = [x[1] for x in points]
plt.title(title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.plot(x_val, y_val)

ROC curve generated with PySpark BinaryClassificationMetrics

BinaryClassificationMetrics в Scala также реализует несколько других полезных методов:

metrics = CurveMetrics(preds)
metrics.get_curve('fMeasureByThreshold')
metrics.get_curve('precisionByThreshold')
metrics.get_curve('recallByThreshold')
0 голосов
/ 17 октября 2018

Поскольку кривая ROC представляет собой график зависимости FPR от TPR, вы можете извлечь необходимые значения следующим образом:

your_model.summary.roc.select('FPR').collect()
your_model.summary.roc.select('TPR').collect())

Где your_model может быть, например, моделью, полученной вами из чего-то вродеthis:

from pyspark.ml.classification import LogisticRegression
log_reg = LogisticRegression()
your_model = log_reg.fit(df)

Теперь вам нужно просто построить FPR против TPR, используя, например, matplotlib.

PS

Вот полныйпример построения кривой ROC с использованием модели с именем your_model (и всего остального!).Я также построил контрольную линию «случайного предположения» внутри графика ROC.

import matplotlib.pyplot as plt
plt.figure(figsize=(5,5))
plt.plot([0, 1], [0, 1], 'r--')
plt.plot(your_model.summary.roc.select('FPR').collect(),
         your_model.summary.roc.select('TPR').collect())
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.show()
...