Я разработал этот простой LogTransformer, расширив UnaryTransformer, чтобы применить логарифмическое преобразование к столбцу age в DataFrame. Я могу применить этот трансформатор, включить его в качестве стадии конвейера и сохранить модель конвейера после обучения.
class LogTransformer(override val uid: String) extends UnaryTransformer[Int,
Double, LogTransformer] with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("logTransformer"))
override protected def createTransformFunc: Int => Double = (feature: Int) => {Math.log10(feature)}
override protected def validateInputType(inputType: DataType): Unit = {
require(inputType == DataTypes.IntegerType, s"Input type must be integer type but got $inputType.")
}
override protected def outputDataType: DataType = DataTypes.DoubleType
override def copy(extra: ParamMap): LogTransformer = defaultCopy(extra)
}
object LogTransformer extends DefaultParamsReadable[LogTransformer]
Но когда я загружаю сохранённую модель, я получаю следующее исключение.
val MODEL_PATH = "model/census_pipeline_model"
cvModel.bestModel.asInstanceOf[PipelineModel].write.overwrite.save(MODEL_PATH)
val same_pipeline_model = PipelineModel.load(MODEL_PATH)
exception in thread "main" java.lang.NoSuchMethodException: dsml.Census$LogTransformer$2.read()
at java.lang.Class.getMethod(Class.java:1786)
at org.apache.spark.ml.util.DefaultParamsReader$.loadParamsInstance(ReadWrite.scala:652)
at org.apache.spark.ml.Pipeline$SharedReadWrite$$anonfun$4.apply(Pipeline.scala:274)
at org.apache.spark.ml.Pipeline$SharedReadWrite$$anonfun$4.apply(Pipeline.scala:272)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:245)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:245)
at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:245)
at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186)
at org.apache.spark.ml.Pipeline$SharedReadWrite$.load(Pipeline.scala:272)
at org.apache.spark.ml.PipelineModel$PipelineModelReader.load(Pipeline.scala:348)
at org.apache.spark.ml.PipelineModel$PipelineModelReader.load(Pipeline.scala:342)
at org.apache.spark.ml.util.MLReadable$class.load(ReadWrite.scala:380)
at org.apache.spark.ml.PipelineModel$.load(Pipeline.scala:332)
at dsml.Census$.main(Census.scala:572)
at dsml.Census.main(Census.scala)
Любые подсказки о том, как это исправить, были бы очень полезны. Спасибо.