Набор меток регрессии xgboost пуст - PullRequest
0 голосов
/ 13 апреля 2019

Я использую Scala-версию xgboost4j для обучения линейной модели и получаю ошибку: /xgboost/src/objective/regression_obj.cu:64: Check failed: info.labels_.Size() != 0U (0 vs. 0) label set cannot be empty

Я пробовал вызвать dtrain.getLabel — он возвращает метки с вещественными значениями

Ниже мой фрагмент кода:

  /** Trains a booster on `dtrain` by delegating to `XGBoost.train`.
    *
    * The hyper-parameters come from `initParams` and the watch list from
    * `initWatches(dtrain, dtest)`.
    *
    * @param dtrain             matrix the model is trained on
    * @param dtest              matrix registered in the watch list under "test"
    * @param round              value forwarded as the `round` argument of `XGBoost.train`
    * @param earlyStoppingRound value forwarded as `earlyStoppingRound`; defaults to
    *                           `EARLY_STOPPING_ROUND`
    * @return the booster returned by `XGBoost.train`
    */
  def train(dtrain: DMatrix, dtest: DMatrix, round: Int, earlyStoppingRound: Int = EARLY_STOPPING_ROUND): Booster =
    XGBoost.train(
      dtrain,
      initParams,
      round,
      initWatches(dtrain, dtest),
      earlyStoppingRound = earlyStoppingRound
    )

  /** Builds the fixed parameter map handed to `XGBoost.train`.
    *
    * The value type is spelled out as `Any` so that integer settings stay `Int`
    * and are not widened to `Double` alongside the floating-point ones.
    *
    * @return an immutable map from parameter name to its configured value
    */
  def initParams: Map[String, Any] =
    Map[String, Any](
      // General Parameters
      "booster" -> "gbtree",
      "verbosity" -> 3,
      "nthread" -> 16,
      // Parameters for Tree Booster
      "eta" -> 0.08, // learning rate
      "gamma" -> 0,
      "max_depth" -> 6,
      "min_child_weight" -> 5,
      "subsample" -> 0.75,
      "colsample_bylevel" -> 0.5,
      "alpha" -> 2.0,
      "lambda" -> 0.5,
      "scale_pos_weight" -> 1,
      // Learning Task Parameters
      "objective" -> "reg:linear"
      // "base_score" -> 0.5
      // "eval_metric" -> "rmse"
    )

  /** Builds the watch list passed to `XGBoost.train`.
    *
    * @param dtrain matrix registered under the key "train"
    * @param dtest  matrix registered under the key "test"
    * @return an immutable two-entry map from watch name to matrix
    */
  def initWatches(dtrain: DMatrix, dtest: DMatrix): Map[String, DMatrix] =
    Map[String, DMatrix]("train" -> dtrain, "test" -> dtest)

Вот исключение, с которым я столкнулся:

Exception in thread "main" ml.dmlc.xgboost4j.java.XGBoostError: [10:14:41] /xgboost/src/objective/regression_obj.cu:64: Check failed: info.labels_.Size() != 0U (0 vs. 0) label set cannot be empty

Stack trace returned 6 entries:
[bt] (0) /tmp/libxgboost4j7469096669937497973.so(dmlc::StackTrace(unsigned long)+0x51) [0x7f4e2a489f21]
[bt] (1) /tmp/libxgboost4j7469096669937497973.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x1d) [0x7f4e2a48ad3d]
[bt] (2) /tmp/libxgboost4j7469096669937497973.so(xgboost::obj::RegLossObj<xgboost::obj::LinearSquareLoss>::GetGradient(xgboost::HostDeviceVector<float> const&, xgboost::MetaInfo const&, int, xgboost::HostDeviceVector<xgboost::detail::GradientPairInternal<float> >*)+0xd5) [0x7f4e2a5905e5]
[bt] (3) /tmp/libxgboost4j7469096669937497973.so(xgboost::LearnerImpl::UpdateOneIter(int, xgboost::DMatrix*)+0x37d) [0x7f4e2a51be8d]
[bt] (4) /tmp/libxgboost4j7469096669937497973.so(XGBoosterUpdateOneIter+0x35) [0x7f4e2a491955]
[bt] (5) [0x7f538fb12747]


    at ml.dmlc.xgboost4j.java.XGBoostJNI.checkCall(XGBoostJNI.java:48)
    at ml.dmlc.xgboost4j.java.Booster.update(Booster.java:127)
    at ml.dmlc.xgboost4j.java.XGBoost.train(XGBoost.java:190)
    at ml.dmlc.xgboost4j.scala.XGBoost$.train(XGBoost.scala:64)
    at com.zhihu.userprofile.pipeline.process.interests.KeywordInterestTrain.train(KeywordInterestTrain.scala:384)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...