В CrossValidator с оценщиком конвейера. Как данные проходят через конвейер? - PullRequest
0 голосов
/ 09 сентября 2018

Взгляните на ML Tuning: перекрестная проверка У меня есть некоторые сомнения по поводу того, как данные проходят через конвейер Spark.

Давайте представим, что я хочу обработать набор данных, подобный следующему:

| sqfeet | #rooms | neighbourhood | price |
|--------|--------|---------------|-------|
| 50     | 2      | GR4242        | 100   |
| 120    | 3      | GR4242        | 220   |
| 100    | 2      | FD0202        | 180   |

То, что я делал в других платформах ML, было:

  1. Предварительная обработка ВСЕ данные. Например, один снимок, кодирующий столбец neighbourhood.
  2. Разделение данных в поезде / тесте.
  3. Выполнить настройку гиперпараметра, используя CV на наборе поездов.
  4. Получите объективную метрику производительности модели, используя набор тестов.

Однако, используя код, который я связал выше, метод transform моего пользовательского Transformer называется 2 * числом сгибов CV * перекрестного произведения сетки параметров.

Тестовая программа:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{Pipeline, Transformer}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}


// Don't worry, the parallelism is set to 1
object Counter {
  var timesCalled = 0
}

class CustomTransformer(override val uid: String = Identifiable.randomUID("custom"))
  extends Transformer {

  override def transform(dataset: Dataset[_]): DataFrame = {
    println(s"Times called ${Counter.timesCalled}. Dataset passed:")
    dataset.show()
    Counter.timesCalled += 1
    dataset.toDF()
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = schema
}

object TestApp {

  def main(args: Array[String]): Unit = {

    lazy val spark: SparkSession =
      SparkSession
        .builder()
        .appName("testApp")
        .master("local[1]")
        .getOrCreate()

    // Prepare training data from a list of (id, text, label) tuples.
    val training = spark.createDataFrame(Seq(
      (50f, 2, "GR4242", 100f),
      (120f, 3, "GR4242", 220f),
      (100f, 2, "FD0202", 180f)
    )).toDF("sqfeet", "#rooms", "neighbourhood", "label")

    //    val stringIndexer = new StringIndexer()
    //      .setInputCol("neighbourhood")
    //      .setOutputCol("neighbourhood_index")
    val assembler = new VectorAssembler()
      .setInputCols(Array("sqfeet", "#rooms"))
      .setOutputCol("features")
    val customTransformer = new CustomTransformer()
    val lr = new LogisticRegression()
      .setMaxIter(10)
    val pipeline = new Pipeline()
      .setStages(Array(assembler, customTransformer, lr))

    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .build()

    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(2)

    val cvModel = cv.fit(training)

  }

} 

Вывод:

Times called 0. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
| 100.0|     2|       FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+

Times called 1. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
|  50.0|     2|       GR4242|100.0| [50.0,2.0]|
| 120.0|     3|       GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+

Times called 2. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
| 100.0|     2|       FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+

Times called 3. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
|  50.0|     2|       GR4242|100.0| [50.0,2.0]|
| 120.0|     3|       GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+

Times called 4. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
|  50.0|     2|       GR4242|100.0| [50.0,2.0]|
| 120.0|     3|       GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+

Times called 5. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
| 100.0|     2|       FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+

Times called 6. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
|  50.0|     2|       GR4242|100.0| [50.0,2.0]|
| 120.0|     3|       GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+

Times called 7. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
| 100.0|     2|       FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+

Times called 8. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label|   features|
+------+------+-------------+-----+-----------+
|  50.0|     2|       GR4242|100.0| [50.0,2.0]|
| 120.0|     3|       GR4242|220.0|[120.0,3.0]|
| 100.0|     2|       FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+

Я полагаю, нечетные числа представляют фазу обучения, а четные числа представляют фазу теста.

Не будет ли эффективнее выполнять все эти дорогостоящие вычисления предварительной обработки, которые выполняются в трансформаторах только один раз?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...