Ваш FlatMapTransformer #transform
неверен, ваш вид отбрасывания / игнорирования всех других столбцов, когда вы выбираете только outputCol
, пожалуйста, измените свой метод на -
override def transform(dataset: Dataset[_]): DataFrame = {
val flatMapUdf = udf(flatMap)
dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
}
Кроме того, Измените свой transformSchema
, чтобы сначала проверить столбец ввода, прежде чем проверять его тип данных -
override def transformSchema(schema: StructType): StructType = {
require(schema.names.contains($(inputCol)), "inputCOl is not there in the input dataframe")
//... rest as it is
}
Обновление-1 на основе комментариев
- Пожалуйста, измените метод
copy
(хотя это не причина исключения, с которым вы столкнулись) -
override def copy(extra: ParamMap): FlatMapTransformer = defaultCopy(extra)
обратите внимание, что
CountVectorizer
принимает столбец, имеющий столбцы типа
ArrayType(StringType, true/false)
, и поскольку выходные столбцы
FlatMapTransformer
становятся входом
CountVectorizer
, вам необходимо убедиться, что выходной столбец
FlatMapTransformer
должен иметь
ArrayType(StringType, true/false)
. Я думаю, что это не так, ваш код сегодня выглядит следующим образом:
override def transform(dataset: Dataset[_]): DataFrame = {
val flatMapUdf = udf(flatMap)
dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
}
Функции explode
преобразуют array<string>
в string
, поэтому выходной сигнал трансформатора становится StringType
. вы можете изменить этот код на -
override def transform(dataset: Dataset[_]): DataFrame = {
val flatMapUdf = udf(flatMap)
dataset.withColumn($(outputCol), flatMapUdf(col($(inputCol))))
}
изменить
transformSchema
метод вывода
ArrayType(StringType)
override def transformSchema(schema: StructType): StructType = {
val dataType = schema($(inputCol)).dataType
require(
dataType.isInstanceOf[StringType],
s"Input column must be of type StringType but got ${dataType}")
val inputFields = schema.fields
require(
!inputFields.exists(_.name == $(outputCol)),
s"Output column ${$(outputCol)} already exists.")
schema.add($(outputCol), ArrayType(StringType))
}
изменить векторный ассемблер на этот -
val featureAssembler = new VectorAssembler()
.setInputCols(Array("cat_features", "num_features", "cat_ohe_features"))
.setOutputCol("features")
Я попытался выполнить ваш конвейер на фиктивном фреймворке данных, он работал хорошо. Полный код см. в этой сущности .