I am working on a custom transformer in Spark 2.2.0. Spark transformers inherit from org.apache.spark.ml.Transformer, and in the example below I want to mix in the HasInputCol and HasOutputCol traits.
package optimizer.feature

import org.apache.spark.annotation.Since
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.sql.functions.{col, struct, udf}
import org.apache.spark.sql.types.{StringType, StructType}

class SplitString(override val uid: String)
  extends Transformer with HasInputCol with HasOutputCol {

  def this() = this(Identifiable.randomUID("HashEncoder"))

  /** @group setParam */
  @Since("1.4.0")
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  @Since("1.4.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def copy(extra: ParamMap): SplitString = defaultCopy(extra)

  override def transform(dataset: Dataset[_]): DataFrame = {
    val split = udf((s: String) => s.split(","))
    val outputSchema = transformSchema(dataset.schema)
    val metadata = outputSchema($(outputCol)).metadata
    dataset.withColumn($(outputCol), split(dataset($(inputCol))))
  }

  override def transformSchema(schema: StructType): StructType = {
    require(
      schema($(inputCol)).dataType.isInstanceOf[StringType],
      s"Input column must be of type StringType but got ${schema($(inputCol)).dataType}")
    val inputFields = schema.fields
    require(
      !inputFields.exists(_.name == $(outputCol)),
      s"Output column ${$(outputCol)} already exists.")
    val attrGroup = new AttributeGroup($(outputCol))
    StructType(schema.fields :+ attrGroup.toStructField())
  }
}
At compile time I get a series of errors:

trait HasInputCol in package shared cannot be accessed in package org.apache.spark.ml.param.shared

and

trait HasOutputCol in package shared cannot be accessed in package org.apache.spark.ml.param.shared
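I suspect this happens because in Spark 2.2 the shared param traits are declared private[ml], so they cannot be mixed in from outside the org.apache.spark.ml package. As a possible workaround I put together the sketch below, which declares equivalent inputCol/outputCol params directly on the transformer instead of mixing in the traits (the class name SplitStringNoShared and the ArrayType output schema are my own choices, not something taken from Spark):

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}

// Hypothetical sketch: same transformer, but with locally declared params
// instead of the private[ml] HasInputCol / HasOutputCol traits.
class SplitStringNoShared(override val uid: String) extends Transformer {

  def this() = this(Identifiable.randomUID("SplitStringNoShared"))

  // Params declared directly on the class, so no access to param.shared is needed.
  final val inputCol: Param[String] =
    new Param[String](this, "inputCol", "input column name")
  final val outputCol: Param[String] =
    new Param[String](this, "outputCol", "output column name")

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def copy(extra: ParamMap): SplitStringNoShared = defaultCopy(extra)

  override def transform(dataset: Dataset[_]): DataFrame = {
    // Validate the schema, then append the split column.
    transformSchema(dataset.schema)
    val split = udf((s: String) => s.split(","))
    dataset.withColumn($(outputCol), split(dataset($(inputCol))))
  }

  override def transformSchema(schema: StructType): StructType = {
    require(schema($(inputCol)).dataType == StringType,
      s"Input column must be of type StringType but got ${schema($(inputCol)).dataType}")
    require(!schema.fieldNames.contains($(outputCol)),
      s"Output column ${$(outputCol)} already exists.")
    StructType(schema.fields :+ StructField($(outputCol), ArrayType(StringType), nullable = true))
  }
}

Is declaring the params locally like this the right approach, or is there a supported way to reuse the shared traits in 2.2.0?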