Обучение XGBoostModel не удалось - PullRequest
1 голос
/ 06 марта 2020

У меня есть код Spark, написанный на python, в котором используется модель XGBoost для прогноза. Проблема, с которой я сталкиваюсь, состоит в том, что в коде есть «для l oop» для прогнозирования другого набора данных в l oop с использованием модели XGBoost и сохранения обученной модели.

Код работает нормально (без ошибок кодирования), но после запуска модели в 10-12 раз в л oop. Он просто выдает ошибку ниже и вылетает из моего приложения Spark. Это даже не работает с большим количеством повторов.

Некоторое устранение неполадок, что я сделал.

-> Это не проблема с данными. Если я перезапущу код с теми же данными, он не удастся, поскольку он будет успешно выполнен и завершится ошибкой после некоторой итерации в l oop

-> Я попытался увеличить ОЗУ, ядра и мониторинг ЦП / исполнителей во время выполнения. Я не вижу никаких проблем. Это не проблема ресурса.

Exception processing forecast: An error occurred while calling o17979.fit.
: ml.dmlc.xgboost4j.java.XGBoostError: XGBoostModel training failed
    at ml.dmlc.xgboost4j.scala.spark.XGBoost$.ml$dmlc$xgboost4j$scala$spark$XGBoost$$postTrackerReturnProcessing(XGBoost.scala:582)
    at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$2.apply(XGBoost.scala:459)
    at ml.dmlc.xgboost4j.scala.spark.XGBoost$$anonfun$trainDistributed$2.apply(XGBoost.scala:435)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
    at scala.collection.immutable.List.foreach(List.scala:392)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
    at scala.collection.immutable.List.map(List.scala:296)
    at ml.dmlc.xgboost4j.scala.spark.XGBoost$.trainDistributed(XGBoost.scala:434)
    at ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor.train(XGBoostRegressor.scala:190)
    at ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor.train(XGBoostRegressor.scala:48)
    at org.apache.spark.ml.Predictor.fit(Predictor.scala:118)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)

Любые указатели на то же самое были бы хороши. Заранее спасибо.

Код

    #Loop through all clusters and train model
    for cluster in range(num_clusters):

        logger.info(f'Cluster : {cluster}')
        logger.info(f'Run Start time : {datetime.now()}')

        #SetModel Path and Model Name  
        model_path = 's3a://' + self.s3_bucket + '/'
        model_name = 'Forecast' + 'Cluster_' + str(cluster)             
        logger.info(f'Model Path : {model_path}')
        logger.info(f'Model Name : {model_name}')

        #Filter based on Cluster and cache
        forecast_vector_cluster_df = forecast_vector_df.where(f'cluster={cluster}')
        forecast_vector_cluster_df.cache()
        logger.info(f'Cluster Filter DF Count  : {forecast_vector_cluster_df.count()}')

        #Set Up param for XGBoost Model
        xgbRegressor = XGBoostRegressor(**self.model_parammap) \
                       .setFeaturesCol("features") \
                       .setLabelCol("predict") \
                       .setPredictionCol(f"prediction_{cluster}")

        logger.info('fitting model')
        xgboostModel = xgbRegressor.fit(forecast_vector_cluster_df)

        logger.info('saving model')
        xgboostModel.write().overwrite().save(model_path + model_name)

1 Ответ

0 голосов
/ 29 марта 2020

Проверьте версию java, установленную на вашем компьютере

Переход на java1.8 может решить проблему

...