Я хотел бы построить классификатор Gradient boosted tree
по PySpark
для задачи классификации мультикласса.Я пытался:
gb = GBTClassifier(maxIter=10)
ovr = OneVsRest(classifier=gb)
ovrModel = ovr.fit(trainingData)
gb_predictions = ovrModel.transform(valData)
evaluator = MulticlassClassificationEvaluator(metricName="accuracy")
gb_accuracy = evaluator.evaluate(gb_predictions)
Когда я запускаю код выше, я получаю эту ошибку:
numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1
AssertionError: Classifier <class 'pyspark.ml.classification.GBTClassifier'> doesn't extend from HasRawPredictionCol.
это примерно строка ovrModel = ovr.fit(trainingData)
, но я не понимаю, чтонеправильно с данными тренировки.