Можно ли использовать модель LinearSVC с OneVsRest в PySpark? - PullRequest
0 голосов
/ 10 октября 2019

Я пытаюсь использовать модель LinearSVC в OneVsRest в PySpark, но кажется, что она еще не поддерживается.

Моя ошибка msg

LinearSVC only supports binary classification. 1 classes detected in LinearSVC_43a50b0b70d60a8cbdb1__labelCol

Какие изменения мне нужны для того, чтобыреализовать это в PySpark?

Кто-нибудь знает, когда OneVsRest в Pyspark будет поддерживать LinearSVC?

1 Ответ

0 голосов
/ 10 октября 2019

Сообщение об ошибке говорит о том, что ваш набор данных в настоящее время содержит только один класс, но LinearSVM - это двоичный алгоритм классификации, который требует ровно двух классов. Я не уверен, что остальная часть вашего кода вызовет какие-либо проблемы, потому что вы ничего не опубликовали. На случай, если вам или кому-то еще это понадобится, посмотрите ниже.

Как сказал alrady, LinearSVM - это алгоритм двоичной классификации, который никогда не будет поддерживать классификацию по нескольким классам по определению, но вы всегда можете уменьшить проблему классификации по нескольким классам. к проблеме бинарной классификации. One-vs-Rest - подход к такому сокращению. Он обучает по одному классификатору на класс, и с инженерной точки зрения имеет смысл разделить его на отдельный класс, такой как spark did . OneVsRest обучает один классификатор для каждого из ваших классов, и данный образец оценивается по этому списку классификаторов. Классификатор с наибольшим количеством баллов определяет прогнозируемую метку для вашего образца.

Посмотрите код ниже для использования OneVsRest с LinearSVC:

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StringIndexer
from pyspark.ml.classification import OneVsRest, LinearSVC
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

df = spark.read.csv('/tmp/iris.data', schema='sepalLength DOUBLE, sepalWidth DOUBLE, petalLength DOUBLE, petalWidth DOUBLE, class STRING')


vecAssembler = VectorAssembler(inputCols=["sepalLength", "sepalWidth", "petalLength", 'petalWidth'], outputCol="features")
df = vecAssembler.transform(df)

stringIndexer = StringIndexer(inputCol="class", outputCol="label")
si_model = stringIndexer.fit(df)
df = si_model.transform(df)

svm = LinearSVC()
ovr = OneVsRest(classifier=svm)
ovrModel = ovr.fit(df)

evaluator = MulticlassClassificationEvaluator(metricName="accuracy")

predictions = ovrModel.transform(df)

print("Accuracy: {}".format(evaluator.evaluate(predictions)))

Вывод:

Accuracy: 0.9533333333333334
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...