Если вам конкретно необходимо сгенерировать кривые ROC для разных порогов, одним из подходов может быть создание списка пороговых значений, которые вас интересуют, и подгонка / преобразование в вашем наборе данных для каждого порога.Или вы можете вручную рассчитать ROC-кривую для каждой пороговой точки, используя поле probability
в ответе от model.transform(test)
.
В качестве альтернативы, вы можете использовать BinaryClassificationMetrics , чтобы извлечь кривую, отображающую различныеметрики (оценка F1, точность, отзыв) по порогу.
К сожалению, похоже, что версия PySpark не реализует большинство методов, которые делает версия Scala, поэтому вам нужно обернуть класс, чтобы сделать это в Python.
Например:
from pyspark.mllib.evaluation import BinaryClassificationMetrics
# Scala version implements .roc() and .pr()
# Python: https://spark.apache.org/docs/latest/api/python/_modules/pyspark/mllib/common.html
# Scala: https://spark.apache.org/docs/latest/api/java/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.html
class CurveMetrics(BinaryClassificationMetrics):
def __init__(self, *args):
super(CurveMetrics, self).__init__(*args)
def _to_list(self, rdd):
points = []
# Note this collect could be inefficient for large datasets
# considering there may be one probability per datapoint (at most)
# The Scala version takes a numBins parameter,
# but it doesn't seem possible to pass this from Python to Java
for row in rdd.collect():
# Results are returned as type scala.Tuple2,
# which doesn't appear to have a py4j mapping
points += [(float(row._1()), float(row._2()))]
return points
def get_curve(self, method):
rdd = getattr(self._java_model, method)().toJavaRDD()
return self._to_list(rdd)
Использование:
import matplotlib.pyplot as plt
preds = predictions.select('label','probability').rdd.map(lambda row: (float(row['probability'][1]), float(row['label'])))
# Returns as a list (false positive rate, true positive rate)
roc = CurveMetrics(preds).get_curve('roc')
plt.figure()
x_val = [x[0] for x in points]
y_val = [x[1] for x in points]
plt.title(title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.plot(x_val, y_val)
Результат:
Вот пример кривой оценки F1 по пороговому значению, если выне состоят в браке с РПЦ: