Взгляните на ML Tuning: перекрестная проверка У меня есть некоторые сомнения по поводу того, как данные проходят через конвейер Spark.
Давайте представим, что я хочу обработать набор данных, подобный следующему:
| sqfeet | #rooms | neighbourhood | price |
|--------|--------|---------------|-------|
| 50 | 2 | GR4242 | 100 |
| 120 | 3 | GR4242 | 220 |
| 100 | 2 | FD0202 | 180 |
То, что я делал в других платформах ML, было:
- Предварительная обработка ВСЕ данные. Например, один снимок, кодирующий столбец
neighbourhood
.
- Разделение данных в поезде / тесте.
- Выполнить настройку гиперпараметра, используя CV на наборе поездов.
- Получите объективную метрику производительности модели, используя набор тестов.
Однако, используя код, который я связал выше, метод transform
моего пользовательского Transformer называется 2 * числом сгибов CV * перекрестного произведения сетки параметров.
Тестовая программа:
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{Pipeline, Transformer}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
// Don't worry, the parallelism is set to 1
object Counter {
var timesCalled = 0
}
class CustomTransformer(override val uid: String = Identifiable.randomUID("custom"))
extends Transformer {
override def transform(dataset: Dataset[_]): DataFrame = {
println(s"Times called ${Counter.timesCalled}. Dataset passed:")
dataset.show()
Counter.timesCalled += 1
dataset.toDF()
}
override def copy(extra: ParamMap): Transformer = defaultCopy(extra)
override def transformSchema(schema: StructType): StructType = schema
}
object TestApp {
def main(args: Array[String]): Unit = {
lazy val spark: SparkSession =
SparkSession
.builder()
.appName("testApp")
.master("local[1]")
.getOrCreate()
// Prepare training data from a list of (id, text, label) tuples.
val training = spark.createDataFrame(Seq(
(50f, 2, "GR4242", 100f),
(120f, 3, "GR4242", 220f),
(100f, 2, "FD0202", 180f)
)).toDF("sqfeet", "#rooms", "neighbourhood", "label")
// val stringIndexer = new StringIndexer()
// .setInputCol("neighbourhood")
// .setOutputCol("neighbourhood_index")
val assembler = new VectorAssembler()
.setInputCols(Array("sqfeet", "#rooms"))
.setOutputCol("features")
val customTransformer = new CustomTransformer()
val lr = new LogisticRegression()
.setMaxIter(10)
val pipeline = new Pipeline()
.setStages(Array(assembler, customTransformer, lr))
val paramGrid = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.1, 0.01))
.build()
val cv = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(new BinaryClassificationEvaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(2)
val cvModel = cv.fit(training)
}
}
Вывод:
Times called 0. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 100.0| 2| FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+
Times called 1. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 50.0| 2| GR4242|100.0| [50.0,2.0]|
| 120.0| 3| GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+
Times called 2. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 100.0| 2| FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+
Times called 3. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 50.0| 2| GR4242|100.0| [50.0,2.0]|
| 120.0| 3| GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+
Times called 4. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 50.0| 2| GR4242|100.0| [50.0,2.0]|
| 120.0| 3| GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+
Times called 5. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 100.0| 2| FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+
Times called 6. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 50.0| 2| GR4242|100.0| [50.0,2.0]|
| 120.0| 3| GR4242|220.0|[120.0,3.0]|
+------+------+-------------+-----+-----------+
Times called 7. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 100.0| 2| FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+
Times called 8. Dataset passed:
+------+------+-------------+-----+-----------+
|sqfeet|#rooms|neighbourhood|label| features|
+------+------+-------------+-----+-----------+
| 50.0| 2| GR4242|100.0| [50.0,2.0]|
| 120.0| 3| GR4242|220.0|[120.0,3.0]|
| 100.0| 2| FD0202|180.0|[100.0,2.0]|
+------+------+-------------+-----+-----------+
Я полагаю, нечетные числа представляют фазу обучения, а четные числа представляют фазу теста.
Не будет ли эффективнее выполнять все эти дорогостоящие вычисления предварительной обработки, которые выполняются в трансформаторах только один раз?