First, you need to set the parameters through ParamGridBuilder, not with setters.
Second, your parameter values must be passed as doubles.
So you'll end up with something like:
import org.apache.spark.ml.feature.CountVectorizer
import org.apache.spark.ml.tuning.ParamGridBuilder
val countVectorizer = new CountVectorizer().setInputCol("subject").setOutputCol("features")
val paramGrid = new ParamGridBuilder().addGrid(countVectorizer.minTF, Array(1.0,3.0,5.0,7.0,9.0)).build()
// paramGrid: Array[org.apache.spark.ml.param.ParamMap] =
// Array({
// cntVec_4eab680c176c-minTF: 1.0
// }, {
// cntVec_4eab680c176c-minTF: 3.0
// }, {
// cntVec_4eab680c176c-minTF: 5.0
// }, {
// cntVec_4eab680c176c-minTF: 7.0
// }, {
// cntVec_4eab680c176c-minTF: 9.0
// })
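Why doubles? minTF on CountVectorizer is declared as a DoubleParam, so the matching addGrid overload expects an Array[Double]; an Array[Int] won't even compile. A quick sketch of the distinction:
// countVectorizer.setMinTF(1.0) would only fix a single value on the estimator;
// it does not build a grid. For the grid, the values must be doubles:
// new ParamGridBuilder().addGrid(countVectorizer.minTF, Array(1, 3, 5))      // Array[Int]: does not compile
new ParamGridBuilder().addGrid(countVectorizer.minTF, Array(1.0, 3.0, 5.0)).build() // Array[Double]: ok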
EDIT:
I can't reproduce your error, but I noticed a few other issues. I've commented on them in the code below, along with the fixes.
// organize imports
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{CountVectorizer, StringIndexer, Tokenizer}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.sql.{DataFrame, SparkSession}
// Create a SparkSession if needed.
val spark = SparkSession.builder().getOrCreate()
// import implicits
import spark.implicits._
// I have created some toy data.
val data: DataFrame = Seq(
("CATEGORY_SOCIAL", "8 popular Pins for you"),
("CATEGORY_PROMOTIONS", "Want to plan with Jira and design in UXPin?"),
("CATEGORY_PROMOTIONS", "Test our new service today"),
("CATEGORY_PROMOTIONS", "deliveries on sunday"),
("CATEGORY_SOCIAL", "Twitter - your friends are missing you")
).toDF("labelss", "subjects")
// The tokenizer is fine, even though the column names can get confusing
val tokenizer: Tokenizer = new Tokenizer().
setInputCol("subjects").
setOutputCol("subject")
// Since we are creating a PipelineModel, it's always better
// to use the column from the previous stage
val countVectorizer: CountVectorizer = new CountVectorizer().
setInputCol(tokenizer.getOutputCol).
setOutputCol("features")
val labelIndexer: StringIndexer = new StringIndexer().
setInputCol("labelss").
setOutputCol("labelsss")
// Same comment here
val logisticRegression: LogisticRegression = new LogisticRegression().setLabelCol(labelIndexer.getOutputCol)
val pipeline: Pipeline = new Pipeline().setStages(Array(tokenizer, countVectorizer, labelIndexer, logisticRegression))
val paramGrid: Array[ParamMap] = new ParamGridBuilder().
addGrid(countVectorizer.minTF, Array(1.0, 3.0, 5.0)).
addGrid(logisticRegression.regParam, Array(0.1, 0.01)).
build()
// This works well. Result:
// paramGrid: Array[org.apache.spark.ml.param.ParamMap] =
// Array({
// cntVec_de795141d282-minTF: 1.0,
// logreg_fe22d7731a7e-regParam: 0.1
// }, {
// cntVec_de795141d282-minTF: 3.0,
// logreg_fe22d7731a7e-regParam: 0.1
// }, {
// cntVec_de795141d282-minTF: 5.0,
// logreg_fe22d7731a7e-regParam: 0.1
// }, {
// cntVec_de795141d282-minTF: 1.0,
// logreg_fe22d7731a7e-regParam: 0.01
// }, {
// cntVec_de795141d282-minTF: 3.0,
// logreg_fe22d7731a7e-regParam: 0.01
// }, {
// cntVec_de795141d282-minTF: 5.0,
// logreg_fe22d7731a7e-regParam: 0.01
// })
// Here is the trick: if you don't explicitly point the evaluator
// at the label column you are actually using, you'll get an error,
// because you are not using the default label column name.
// Something like: Caused by: java.lang.IllegalArgumentException: Field "label" does not exist.
val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelIndexer.getOutputCol)
// evaluator: org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator = mcEval_c9d72a485d1d
val cv: CrossValidator = new CrossValidator().
setEstimator(pipeline).
setEvaluator(evaluator).
setEstimatorParamMaps(paramGrid).
setNumFolds(10). // 10-fold cross-validation
setParallelism(2). // Evaluate up to 2 parameter settings in parallel
setSeed(123) // random seed
// cv: org.apache.spark.ml.tuning.CrossValidator = cv_2e1c55435a49
val model: CrossValidatorModel = cv.fit(data)
// model: org.apache.spark.ml.tuning.CrossValidatorModel = cv_2e1c55435a49
val result: DataFrame = model.transform(data)
// result: org.apache.spark.sql.DataFrame = [labelss: string, subjects: string ... 6 more fields]
result.show
// +-------------------+--------------------+--------------------+--------------------+--------+--------------------+--------------------+----------+
// | labelss| subjects| subject| features|labelsss| rawPrediction| probability|prediction|
// +-------------------+--------------------+--------------------+--------------------+--------+--------------------+--------------------+----------+
// | CATEGORY_SOCIAL|8 popular Pins fo...|[8, popular, pins...|(28,[0,8,16,21,25...| 1.0|[-2.5645425270090...|[0.07145555978623...| 1.0|
// |CATEGORY_PROMOTIONS|Want to plan with...|[want, to, plan, ...|(28,[1,6,9,17,18,...| 0.0|[3.57523120417979...|[0.97275417761670...| 0.0|
// |CATEGORY_PROMOTIONS|Test our new serv...|[test, our, new, ...|(28,[3,4,10,12,20...| 0.0|[3.15934297459226...|[0.95927528667918...| 0.0|
// |CATEGORY_PROMOTIONS|deliveries on sunday|[deliveries, on, ...|(28,[5,22,26],[1....| 0.0|[2.81641463947790...|[0.94355642175747...| 0.0|
// | CATEGORY_SOCIAL|Twitter - your fr...|[twitter, -, your...|(28,[0,2,7,11,13,...| 1.0|[-2.8838332277996...|[0.05295855512212...| 1.0|
// +-------------------+--------------------+--------------------+--------------------+--------+--------------------+--------------------+----------+
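If you also want to see which grid point won, you can pull the fitted stages out of the CrossValidatorModel. A sketch (the stage indices follow the Pipeline defined above):
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.classification.LogisticRegressionModel
import org.apache.spark.ml.feature.CountVectorizerModel

// One average metric per ParamMap, in the same order as paramGrid
model.avgMetrics

// The best pipeline and the parameter values it was trained with
val bestPipeline = model.bestModel.asInstanceOf[PipelineModel]
val bestMinTF = bestPipeline.stages(1).asInstanceOf[CountVectorizerModel].getMinTF
val bestRegParam = bestPipeline.stages(3).asInstanceOf[LogisticRegressionModel].getRegParam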
Note: I didn't split my data into training and test sets for purely practical reasons: there isn't enough of it here to split.
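For reference, with enough data the split would look something like this (the 80/20 weights are just an example):
// Hold out a test set and tune on the training portion only
val Array(train, test) = data.randomSplit(Array(0.8, 0.2), seed = 123)
val tunedModel = cv.fit(train)
tunedModel.transform(test).show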