Есть ли способ извлечь метку столбца для корневого узла в модели классификатора дерева решений Spark 2.1.3 mllib (Scala)? - PullRequest
0 голосов
/ 03 января 2019

В настоящее время я работаю над набором данных, который требует классификации дерева решений. Созданный мной DataFrame является SparseVector, использующим пример из mllib example . Я знаю, что есть способ извлечь примеси и предсказания для корневого узла, однако я хочу иметь возможность получить имя столбца этого узла.

Предположим, мой DataFrame выглядит следующим образом:

+-----|-------|-------|-------+
| id  | col_1 | col_2 | col_3 |
+-----|-------|-------|-------+
| 0   | 1.0   | 0.0   | 2.0   |
| 1   | 2.0   | 1.0   | 0.0   |
| 3   | 2.0   | 2.0   | 1.0   |
+-----|-------|-------|-------+

Мой последний набор данных содержит гораздо больше столбцов, чем этот, но только для того, чтобы показать пример того, с чем я работаю.

Затем я использую преобразование этого в VectorIndexer, используя пример кода, который даст мне что-то похожее на это:

+-----|--------------------------------+
| id  | features                       | 
+-----|--------------------------------+
| 0   | (3, [0, 2], [1.0, 2.0])        |
| 1   | (3, [0, 1]. [2.0, 1.0])        |
| 3   | (3, [0, 1, 2], [2.0, 2.0, 1.0] | 
+-----|--------------------------------+

В конце концов, у меня есть обученное дерево классификации, и я могу получить корневой узел из дерева. Но я хочу получить имя столбца, связанное с корневым узлом.

val data = spark.read.format("libsvm").load("./src/main/resources/sample_libsvm_data.txt")

val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)

val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4) // features with > 4 distinct values are treated as continuous.
  .fit(data)

val (trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")

val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)


val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))
println("Pipeline stages: ")
pipeline.getStages.foreach(println)

val model = pipeline.fit(trainingData)

val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]
println("Learned classification tree model:\n" + treeModel.toDebugString)

val root = treeModel.rootNode
println("Root: " + root.toString)

Это даст мне что-то, например:

Trained classification tree model:
DecisionTreeClassificationModel (uid=dtc_5eadf281a5fc) of depth 4 with 13 nodes
  If (feature 32 in {1.0,2.0,3.0,4.0})
   If (feature 4 in {0.0})
    Predict: 2.0
   Else (feature 4 not in {0.0})
    Predict: 1.0
  Else (feature 32 not in {1.0,2.0,3.0,4.0})
   If (feature 37 in {0.0,2.0})
    If (feature 1 in {1.0})
     Predict: 1.0
    Else (feature 1 not in {1.0})
     If (feature 4 in {0.0,3.0})
      Predict: 0.0
     Else (feature 4 not in {0.0,3.0})
      Predict: 2.0
   Else (feature 37 not in {0.0,2.0})
    If (feature 3 in {0.0})
     Predict: 2.0
    Else (feature 3 not in {0.0})
     Predict: 1.0

Root: InternalNode(prediction = 0.0, impurity = 0.6428571428571429, split = org.apache.spark.ml.tree.CategoricalSplit@a0058edd)

Как вы можете видеть, я вижу корень из дерева feature 32, но я хочу иметь возможность получить имя столбца для него.

Заранее спасибо за любую помощь.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...