Установка параметра scalePosWeight для модели Spark xgBoost в сетке CV - PullRequest
0 голосов
/ 29 июня 2018

Я пытаюсь настроить свою модель xgBoost на Spark, используя Scala. Моя сетка параметров XGb выглядит следующим образом:

val xgbParamGrid = (new ParamGridBuilder()
                .addGrid(xgb.maxDepth, Array(8, 16))
                .addGrid(xgb.minChildWeight, Array(0.5, 1, 2))
                .addGrid(xgb.alpha, Array(0.8, 0.9, 1))
                .addGrid(xgb.lambda, Array(0.8, 1, 2))
                .addGrid(xgb.scalePosWeight, Array(1, 5, 9))
                .addGrid(xgb.subSample, Array(0.5, 0.8, 1))
                .addGrid(xgb.eta, Array(0.01, 0.1, 0.3, 0.5))
                .build())

вызов кросс-валидатора выглядит следующим образом:

val evaluator = (new BinaryClassificationEvaluator()
                      .setLabelCol("label")
                      .setRawPredictionCol("prediction")
                      .setMetricName("areaUnderPR"))

    val cv = (new CrossValidator()
              .setEstimator(pipeline_model_xgb)
              .setEvaluator(evaluator)
              .setEstimatorParamMaps(xgbParamGrid)
              .setNumFolds(10))

    val xgb_model = cv.fit(train)

Я получаю следующую ошибку только для параметра scalePosWeight:

error: type mismatch;
found   : org.apache.spark.ml.param.DoubleParam
required: org.apache.spark.ml.param.Param[AnyVal]
Note: Double <: AnyVal (and org.apache.spark.ml.param.DoubleParam <:                      

    org.apache.spark.ml.param.Param[Double]), but class Param is invariant in type T.
You may wish to define T as +T instead. (SLS 4.5)
                              .addGrid(xgb.scalePosWeight, Array(1, 5, 9))
                                           ^

Исходя из моего поиска, сообщение "Вы можете определить вместо T T + +" является обычным, но я не уверен, как это исправить здесь. Спасибо за вашу помощь!

1 Ответ

0 голосов
/ 07 августа 2018

Я столкнулся с той же проблемой при установке массива для minChildWeight, и массив был составлен только для типов Int. Решение, которое сработало (как для scalePosWeight, так и для minChildWeight), заключается в передаче массива с плавающей точкой:

.addGrid(xgb.scalePosWeight, Array(1.0, 5.0, 9.0))
...