Несоответствие между локальной обученной и обученной Datapro c моделью Spark ML - PullRequest
1 голос
/ 27 мая 2020

Я обновляю Spark с версии 2.3.1 до версии 2.4.5. Я переобучаю модель со Spark 2.4.5 на Datapro c Google Cloud Platform, используя Datapro c image 1.4.27-debian9. Когда я загружаю модель, созданную Datapro c, на свой локальный компьютер, используя Spark 2.4.5 для проверки модели. К сожалению, я получаю следующее исключение:

20/05/27 08:36:35 INFO HadoopRDD: Input split: file:/Users/.../target/classes/model.ml/stages/1_gbtc_961a6ef213b2/metadata/part-00000:0+657
20/05/27 08:36:35 INFO HadoopRDD: Input split: file:/Users/.../target/classes/model.ml/stages/1_gbtc_961a6ef213b2/metadata/part-00000:0+657
Exception in thread "main" java.lang.IllegalArgumentException: gbtc_961a6ef213b2 parameter impurity given invalid value variance.

Код для загрузки модели довольно прост:

import org.apache.spark.ml.PipelineModel

object ModelLoad {
  def main(args: Array[String]): Unit = {
    val modelInputPath = getClass.getResource("/model.ml").getPath
    val model = PipelineModel.load(modelInputPath)
  }
}

Я следил за трассировкой стека, чтобы проверить метаданные модели 1_gbtc_961a6ef213b2/metadata/part-00000 файл и обнаружил следующее:

{
    "class": "org.apache.spark.ml.classification.GBTClassificationModel",
    "timestamp": 1590593177604,
    "sparkVersion": "2.4.5",
    "uid": "gbtc_961a6ef213b2",
    "paramMap": {
        "maxIter": 50
    },
    "defaultParamMap": {
        ...
        "impurity": "variance",
        ...
    },
    "numFeatures": 1,
    "numTrees": 50
}

Примесь установлена ​​на variance, но моя локальная искра 2.4.5 ожидает, что это будет gini. Для проверки работоспособности переобучил модель на своей локальной искре 2.4.5. Для impurity в файле метаданных модели установлено значение gini.

Итак, проверил спарк 2.4.5 setImpurity method в GBT Javado c. Там написано The impurity setting is ignored for GBT models. Individual trees are built using impurity "Variance.". Искра 2.4.5, используемая Datapro c, похоже, соответствует документации Apache Spark. Но Spark 2.4.5, который я использую из Maven Central, устанавливает значение impurity на gini.

Кто-нибудь знает, почему существует такое несоответствие между Spark 2.4.5 в Datapro c и Maven Central?

Я создал простой обучающий код для локального воспроизведения результата:

import java.nio.file.Paths

import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.{DataFrame, SparkSession}

object SimpleModelTraining {
  def main(args: Array[String]) {


    val currentRelativePath = Paths.get("")
    val save_file_location = currentRelativePath.toAbsolutePath.toString

    val spark = SparkSession.builder()
      .config("spark.driver.host", "127.0.0.1")
      .master("local")
      .appName("spark-test")
      .getOrCreate()

    val df: DataFrame = spark.createDataFrame(Seq(
      (0, 0),
      (1, 0),
      (1, 0),
      (0, 1),
      (0, 1),
      (0, 1),
      (0, 2),
      (0, 2),
      (0, 2),
      (0, 3),
      (0, 3),
      (0, 3),
      (1, 4),
      (1, 4),
      (1, 4)
    )).toDF("label", "category")

    val pipeline: Pipeline = new Pipeline().setStages(Array(
      new VectorAssembler().setInputCols(Array("category")).setOutputCol("features"),
      new GBTClassifier().setMaxIter(30)
    ))

    val pipelineModel: PipelineModel = pipeline.fit(df)
    pipelineModel.write.overwrite().save(s"$save_file_location/test_model.ml")
  }
}

Спасибо!

1 Ответ

1 голос
/ 28 мая 2020

Spark в Datapro c с обратным переносом исправление для SPARK-25959 , которое может вызвать несоответствие между вашими локально обученными и обученными Datapro c моделями машинного обучения.

...