Использовать сохраненную модель двоичной классификации дерева решений Spark Mllib для прогнозирования новых данных - PullRequest
0 голосов
/ 30 октября 2018

Я использую Spark версии 2.2.0 и scala версии 2.11.8. Я создал и сохранил модель двоичной классификации дерева решений, используя следующий код:

package...
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SparkSession


object DecisionTreeClassification {

def main(args: Array[String]): Unit = {

val sparkSession = SparkSession.builder
  .master("local[*]")
  .appName("Decision Tree")
  .getOrCreate()
// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sparkSession.sparkContext, "path/to/file/xyz.txt")
// Split the data into training and test sets (20% held out for testing)
val splits = data.randomSplit(Array(0.8, 0.2))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
//  Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)

// Evaluate model on test instances and compute test error
val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction) 
}
val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
println(s"Test Error = $testErr")
println(s"Learned classification tree model:\n ${model.toDebugString}")

// Save and load model
model.save(sparkSession.sparkContext, "target/tmp/myDecisionTreeClassificationModel")
val sameModel = DecisionTreeModel.load(sparkSession.sparkContext, "target/tmp/myDecisionTreeClassificationModel")
// $example off$

sparkSession.sparkContext.stop()
}  
}

Теперь я хочу предсказать метку (0 или 1) для новых данных, используя эту сохраненную модель. Я новичок в Spark, кто-нибудь, пожалуйста, дайте мне знать, как это сделать?

1 Ответ

0 голосов
/ 01 ноября 2018

Я нашел ответ на этот вопрос, поэтому я подумал, что должен поделиться им, если кто-то ищет ответ на аналогичный вопрос

Чтобы сделать прогноз для новых данных, просто добавьте несколько строк перед остановкой сеанса спарка:

 val newData = MLUtils.loadLibSVMFile(sparkSession.sparkContext, "path/to/file/abc.txt")

 val newDataPredictions = newData.map 
    { point =>
      val newPrediction = model.predict(point.features)
      (point.label, newPrediction)
    }
    newDataPredictions.foreach(f => println("Predicted label", f._2))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...