Я работаю над конвейером (pipeline) для обучения модели и использую как числовые, так и категориальные признаки. Конвейер прекрасно работает с каждым типом признаков по отдельности, но перестаёт работать, как только я пытаюсь объединить признаки разных типов.
/**
 * Builds the (unfitted) feature-preprocessing pipeline.
 *
 * Feature names are read from the task configuration through `getFeaturesByType`,
 * once for "categorical" and once for "numeric". Every categorical feature gets
 * its own `StringIndexer` that writes to `<feature>_index`. The indexed
 * categorical columns are assembled into "cat_features", the numeric columns
 * into "num_features", and those two vectors are assembled into "features".
 *
 * Side effects: assigns `this.modelFeatures` (indexed categorical column names
 * followed by the numeric column names) and logs the stage list through
 * `TaskLogger.INFO`.
 *
 * @return a `Pipeline` whose stages are the string indexers followed by the
 *         categorical, numeric and combined vector assemblers, in that order
 */
private def pipeline: Pipeline = {
  val categoricalCols = getFeaturesByType(this.taskConfig, "categorical")
  val numericCols = getFeaturesByType(this.taskConfig, "numeric")

  // Output column names produced by the per-feature string indexers below.
  val indexedCols = categoricalCols.map(name => name + "_index")
  val indexers = categoricalCols.map { name =>
    new StringIndexer()
      .setInputCol(name)
      .setOutputCol(name + "_index")
  }

  // Record the flat list of columns that end up in the model's feature vector.
  this.modelFeatures = indexedCols ++ numericCols

  // Small local factory so the three assemblers are configured uniformly.
  def assemblerFor(inputs: Array[String], output: String): VectorAssembler =
    new VectorAssembler().setInputCols(inputs).setOutputCol(output)

  val categoricalAssembler = assemblerFor(indexedCols, "cat_features")
  val numericAssembler = assemblerFor(numericCols, "num_features")
  val combinedAssembler = assemblerFor(Array("cat_features", "num_features"), "features")

  // Indexers must run first: the categorical assembler consumes their output columns.
  val stages = indexers ++
    Array(categoricalAssembler, numericAssembler, combinedAssembler)

  TaskLogger.INFO(stages.toList.toString())
  new Pipeline().setStages(stages)
}
/**
 * Creates an `XGBoostEstimator` from the given parameter map and points it at
 * the "features" column for input and the "indexed_label" column for labels.
 *
 * @param paramMap parameters passed unchanged to the `XGBoostEstimator` constructor
 * @return the configured estimator
 */
private def estimator(paramMap: Map[String, Any]): XGBoostEstimator = {
  val base = new XGBoostEstimator(paramMap)
  val withFeatures = base.setFeaturesCol("features")
  withFeatures.setLabelCol("indexed_label")
}
Приложение аварийно завершается при вызове:
// Run the input dataset through pipelineModelTransform to obtain the preprocessed data.
val preprocessed = this.pipelineModelTransform(dataset)
// Fit xgbEstimator on the preprocessed data and store the fitted model wrapped in Some.
// NOTE(review): presumably this fit call is the one that raises the
// "SparkContext has been shutdown" IllegalStateException — confirm against the trace.
this.estimatorModel = Some(xgbEstimator.fit(preprocessed))
Сообщение об ошибке — SparkContext был остановлен:
Exception in thread "main" java.lang.IllegalStateException: SparkContext has been shutdown
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2021)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2050)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2069)
at org.apache.spark.rdd.RDD$$anonfun$take$1.apply(RDD.scala:1354)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:362)
at org.apache.spark.rdd.RDD.take(RDD.scala:1327)
at org.apache.spark.rdd.RDD$$anonfun$first$1.apply(RDD.scala:1368)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:362)
at org.apache.spark.rdd.RDD.first(RDD.scala:1367)
at ml.dmlc.xgboost4j.scala.spark.XGBoost$.ml$dmlc$xgboost4j$scala$spark$XGBoost$$postTrackerReturnProcessing(XGBoost.scala:393)
at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$4.apply(XGBoost.scala:356)
at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$4.apply(XGBoost.scala:337)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.immutable.List.foreach(List.scala:381)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
at scala.collection.immutable.List.map(List.scala:285)
at ml.dmlc.xgboost4j.scala.spark.XGBoost$.trainDistributed(XGBoost.scala:336)
at ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator.train(XGBoostEstimator.scala:139)
at ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator.train(XGBoostEstimator.scala:36)
at org.apache.spark.ml.Predictor.fit(Predictor.scala:118)
at jp.co.yahoo.k2.i2i.optimizer.model.XGBoostRanker.fit(XGBoostRanker.scala:67)
at jp.co.yahoo.k2.i2i.optimizer.Optimizer$.train(Optimizer.scala:151)
at jp.co.yahoo.k2.i2i.optimizer.Optimizer$.main(Optimizer.scala:78)
at jp.co.yahoo.k2.i2i.optimizer.Optimizer.main(Optimizer.scala)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at org.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:782)
at org.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:180)
at org.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:205)
at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:119)
at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala)
Я использую xgboost4j-spark 0.72 и Spark 2.2.0.