Spark ML: Как DecisionTreeClassificatonModel узнает о весах деревьев? - PullRequest
0 голосов
/ 26 февраля 2019

Я бы хотел получить вес для узлов дерева из сохраненного (или несохраненного) DecisionTreeClassificationModel.Однако я не могу найти ничего отдаленно напоминающего это.

Как модель фактически выполняет классификацию, не зная ни одного из них.Ниже приведены параметры, которые сохраняются в модели:

{"class":"org.apache.spark.ml.classification.DecisionTreeClassificationModel"
"timestamp":1551207582648
"sparkVersion":"2.3.2"
"uid":"DecisionTreeClassifier_4ffc94d20f1ddb29f282"
"paramMap":{
"cacheNodeIds":false
"maxBins":32
"minInstancesPerNode":1
"predictionCol":"prediction"
"minInfoGain":0.0
"rawPredictionCol":"rawPrediction"
"featuresCol":"features"
"probabilityCol":"probability"
"checkpointInterval":10
"seed":956191873026065186
"impurity":"gini"
"maxMemoryInMB":256
"maxDepth":2
"labelCol":"indexed"
}
"numFeatures":1
"numClasses":2
}

1 Ответ

0 голосов
/ 27 февраля 2019

Используя treeWeights:

treeWeights

Возвращает веса для каждого дерева

Новое в версии 1.5.0.

Итак

Как модель фактически выполняет классификацию, не зная ни одного из них.

Веса сохраняются, простоне как часть метаданных.Если у вас есть model

from pyspark.ml.classification import RandomForestClassificationModel

model: RandomForestClassificationModel = ...

и сохраните его на диск

path: str = ...

model.save(path)

, вы увидите, что программа записи создает подкаталог treesMetadata.Если вы загрузите содержимое (по умолчанию средство записи использует Parquet):

import os

trees_metadata = spark.read.parquet(os.path.join(path, "treesMetadata"))

вы увидите следующую структуру:

trees_metadata.printSchema()
root
 |-- treeID: integer (nullable = true)
 |-- metadata: string (nullable = true)
 |-- weights: double (nullable = true)

, где weights столбец содержит вес дереваидентифицируется treeID.

Аналогично данные узла хранятся в подкаталоге data (см., например, Извлечение и визуализация деревьев моделей из Sparklyr ):

spark.read.parquet(os.path.join(path, "data")).printSchema()     
root
 |-- id: integer (nullable = true)
 |-- prediction: double (nullable = true)
 |-- impurity: double (nullable = true)
 |-- impurityStats: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- gain: double (nullable = true)
 |-- leftChild: integer (nullable = true)
 |-- rightChild: integer (nullable = true)
 |-- split: struct (nullable = true)
 |    |-- featureIndex: integer (nullable = true)
 |    |-- leftCategoriesOrThreshold: array (nullable = true)
 |    |    |-- element: double (containsNull = true)
 |    |-- numCategories: integer (nullable = true)

Эквивалентная информация (за исключением данных дерева и весов дерева) доступна также для DecisionTreeClassificationModel.

...