How do I get Spark to fit a GBT model without errors? - PullRequest
0 votes
/ August 3, 2020

I downloaded the data from https://www.kaggle.com/c/GiveMeSomeCredit and am trying to follow an example to fit a GBT model with Spark via pyspark:

dataframe = (....) # data obtained by reading CSV

labelIndexer = StringIndexer(inputCol="SeriousDlq2yrs", outputCol="indexedLabel")
vectorAssembler = VectorAssembler(inputCols=dataframe.columns.remove("SeriousDlq2yrs"), outputCol="rawFeatures")
featureIndexer = VectorIndexer(inputCol="rawFeatures", outputCol="indexedFeatures", maxCategories=4)

(trainingData, testData) = dataframe.randomSplit([0.7, 0.3])

gbt = GBTClassifier(featuresCol="indexedFeatures", labelCol="indexedLabel")

paramGrid = ParamGridBuilder().addGrid(gbt.maxDepth, [5, 6]).addGrid(gbt.maxIter, [5, 10]).build()

# We define an evaluation metric.  This tells CrossValidator how well we are doing by comparing the true labels with predictions.
evaluator = MulticlassClassificationEvaluator(metricName="accuracy", labelCol=gbt.getLabelCol(), predictionCol=gbt.getPredictionCol())
# Declare the CrossValidator, which runs model tuning for us.
cv = CrossValidator(estimator=gbt, evaluator=evaluator, estimatorParamMaps=paramGrid)

model = Pipeline(stages=[vectorAssembler, featureIndexer, labelIndexer, gbt]).fit(trainingData)

I thought this would work, but instead I get:

*** py4j.protocol.Py4JJavaError: An error occurred while calling o417.transform.
: java.lang.NullPointerException
        at scala.collection.mutable.ArrayOps$ofRef$.length$extension(ArrayOps.scala:192)
        at scala.collection.mutable.ArrayOps$ofRef.length(ArrayOps.scala:192)
        at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:32)
        at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
        at scala.collection.TraversableLike$class.flatMap(TraversableLike.scala:241)
        at scala.collection.mutable.ArrayOps$ofRef.flatMap(ArrayOps.scala:186)
        at org.apache.spark.ml.feature.VectorAssembler.transformSchema(VectorAssembler.scala:161)
        at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:74)
        at org.apache.spark.ml.feature.VectorAssembler.transform(VectorAssembler.scala:86)
        at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
        at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.lang.reflect.Method.invoke(Method.java:498)
        at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
        at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
        at py4j.Gateway.invoke(Gateway.java:282)
        at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
        at py4j.commands.CallCommand.execute(CallCommand.java:79)
        at py4j.GatewayConnection.run(GatewayConnection.java:238)
        at java.lang.Thread.run(Thread.java:748)

which tells me very little. How can I debug this?
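While debugging, one plausible cause stands out (an assumption on my part, not something the trace states directly): Python's `list.remove` mutates the list in place and returns `None`, so `inputCols=dataframe.columns.remove("SeriousDlq2yrs")` would pass `None` to `VectorAssembler`, which could explain the `NullPointerException` inside `transformSchema`. A minimal sketch of the suspect expression, no Spark needed (the column names besides `SeriousDlq2yrs` are made up for illustration):

```python
# list.remove() mutates the list in place and returns None, so the value
# handed to VectorAssembler(inputCols=...) in the code above is None.
columns = ["age", "MonthlyIncome", "SeriousDlq2yrs"]
print(columns.remove("SeriousDlq2yrs"))  # -> None

# A list comprehension keeps the filtered value and leaves the original intact:
feature_cols = [c for c in ["age", "MonthlyIncome", "SeriousDlq2yrs"]
                if c != "SeriousDlq2yrs"]
print(feature_cols)  # -> ['age', 'MonthlyIncome']
```

If that is the cause, replacing the `inputCols=` argument with such a comprehension would be the first thing to try.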

...