Spark: конвейер FlatMap и CountVectorizer - PullRequest
2 голосов
/ 26 мая 2020

Я работаю над конвейером и пытаюсь разделить значение столбца перед передачей его в CountVectorizer.

Для этой цели я сделал собственный преобразователь.

class FlatMapTransformer(override val uid: String)
  extends Transformer {
  /**
   * Param for input column name.
   * @group param
   */
  final val inputCol = new Param[String](this, "inputCol", "The input column")
  final def getInputCol: String = $(inputCol)

  /**
   * Param for output column name.
   * @group param
   */
  final val outputCol = new Param[String](this, "outputCol", "The output column")
  final def getOutputCol: String = $(outputCol)

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)

  def this() = this(Identifiable.randomUID("FlatMapTransformer"))

  private val flatMap: String => Seq[String] = { input: String =>
    input.split(",")
  }

  override def copy(extra: ParamMap): SplitString = defaultCopy(extra)

  override def transform(dataset: Dataset[_]): DataFrame = {
    val flatMapUdf = udf(flatMap)
    dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
  }

  override def transformSchema(schema: StructType): StructType = {
    val dataType = schema($(inputCol)).dataType
    require(
      dataType.isInstanceOf[StringType],
      s"Input column must be of type StringType but got ${dataType}")
    val inputFields = schema.fields
    require(
      !inputFields.exists(_.name == $(outputCol)),
      s"Output column ${$(outputCol)} already exists.")

    DataTypes.createStructType(
      Array(
        DataTypes.createStructField($(outputCol), DataTypes.StringType, false)))
  }
}

Код кажется le git, но когда я пытаюсь связать его с другой операцией, возникает проблема. Вот мой конвейер:

val train = reader.readTrainingData()

val cat_features = getFeaturesByType(taskConfig, "categorical")
val num_features = getFeaturesByType(taskConfig, "numeric")
val cat_ohe_features = getFeaturesByType(taskConfig, "categorical", Some("ohe"))
val cat_features_string_index = cat_features.
  filter { feature: String => !cat_ohe_features.contains(feature) }

val catIndexer = cat_features_string_index.map {
  feature =>
    new StringIndexer()
      .setInputCol(feature)
      .setOutputCol(feature + "_index")
      .setHandleInvalid("keep")
}

    val flatMapper = cat_ohe_features.map {
      feature =>
        new FlatMapTransformer()
          .setInputCol(feature)
          .setOutputCol(feature + "_transformed")
    }

    val countVectorizer = cat_ohe_features.map {
      feature =>

        new CountVectorizer()
          .setInputCol(feature + "_transformed")
          .setOutputCol(feature + "_vectorized")
          .setVocabSize(10)
    }


// val countVectorizer = cat_ohe_features.map {
//   feature =>
//
//     val flatMapper = new FlatMapTransformer()
//       .setInputCol(feature)
//       .setOutputCol(feature + "_transformed")
// 
//     new CountVectorizer()
//       .setInputCol(flatMapper.getOutputCol)
//       .setOutputCol(feature + "_vectorized")
//       .setVocabSize(10)
// }

val cat_features_index = cat_features_string_index.map {
  (feature: String) => feature + "_index"
}

val count_vectorized_index = cat_ohe_features.map {
  (feature: String) => feature + "_vectorized"
}

val catFeatureAssembler = new VectorAssembler()
  .setInputCols(cat_features_index)
  .setOutputCol("cat_features")

val oheFeatureAssembler = new VectorAssembler()
  .setInputCols(count_vectorized_index)
  .setOutputCol("cat_ohe_features")

val numFeatureAssembler = new VectorAssembler()
  .setInputCols(num_features)
  .setOutputCol("num_features")

val featureAssembler = new VectorAssembler()
  .setInputCols(Array("cat_features", "num_features", "cat_ohe_features_vectorized"))
  .setOutputCol("features")

val pipelineStages = catIndexer ++ flatMapper ++ countVectorizer ++
  Array(
    catFeatureAssembler,
    oheFeatureAssembler,
    numFeatureAssembler,
    featureAssembler)

val pipeline = new Pipeline().setStages(pipelineStages)
pipeline.fit(dataset = train)

Запуская этот код, я получаю сообщение об ошибке: java.lang.IllegalArgumentException: Field "my_ohe_field_trasformed" does not exist.

[info]  java.lang.IllegalArgumentException: Field "from_expdelv_areas_transformed" does not exist.

[info]  at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
[info]  at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)

[info]  at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)

[info]  at scala.collection.AbstractMap.getOrElse(Map.scala:59)

[info]  at org.apache.spark.sql.types.StructType.apply(StructType.scala:265)

[info]  at org.apache.spark.ml.util.SchemaUtils$.checkColumnTypes(SchemaUtils.scala:56)

[info]  at org.apache.spark.ml.feature.CountVectorizerParams$class.validateAndTransformSchema(CountVectorizer.scala:75)

[info]  at org.apache.spark.ml.feature.CountVectorizer.validateAndTransformSchema(CountVectorizer.scala:123)

[info]  at org.apache.spark.ml.feature.CountVectorizer.transformSchema(CountVectorizer.scala:188)

Когда я раскомментирую stringSplitter и countVectorizer, возникает ошибка. в моем Transformer

java.lang.IllegalArgumentException: Field "my_ohe_field" does not exist. at val dataType = schema($(inputCol)).dataType

Результат звонка pipeline.getStages:

strIdx_3c2630a738f0

strIdx_0d76d55d4200

FlatMapTransformer_fd8595c2969c

FlatMapTransformer_2e9a7af0b0fa

cntVec_c2ef31f00181

cntVec_68a78eca06c9

vecAssembler_a81dd9f43d56

vecAssembler_b647d348f0a0

vecAssembler_b5065a22d5c8

vecAssembler_d9176b8bb593

Я могу пойти по неправильному пути. Любые комментарии приветствуются.

1 Ответ

2 голосов
/ 26 мая 2020

Ваш FlatMapTransformer #transform неверен, ваш вид отбрасывания / игнорирования всех других столбцов, когда вы выбираете только outputCol

, пожалуйста, измените свой метод на -

 override def transform(dataset: Dataset[_]): DataFrame = {
     val flatMapUdf = udf(flatMap)
    dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
  }

Кроме того, Измените свой transformSchema, чтобы сначала проверить столбец ввода, прежде чем проверять его тип данных -

 override def transformSchema(schema: StructType): StructType = {
require(schema.names.contains($(inputCol)), "inputCOl is not there in the input dataframe")
//... rest as it is
}

Обновление-1 на основе комментариев

  1. Пожалуйста, измените метод copy (хотя это не причина исключения, с которым вы столкнулись) -
override def copy(extra: ParamMap): FlatMapTransformer = defaultCopy(extra)
обратите внимание, что CountVectorizer принимает столбец, имеющий столбцы типа ArrayType(StringType, true/false), и поскольку выходные столбцы FlatMapTransformer становятся входом CountVectorizer, вам необходимо убедиться, что выходной столбец FlatMapTransformer должен иметь ArrayType(StringType, true/false). Я думаю, что это не так, ваш код сегодня выглядит следующим образом:
  override def transform(dataset: Dataset[_]): DataFrame = {
    val flatMapUdf = udf(flatMap)
    dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
  }

Функции explode преобразуют array<string> в string, поэтому выходной сигнал трансформатора становится StringType. вы можете изменить этот код на -

  override def transform(dataset: Dataset[_]): DataFrame = {
    val flatMapUdf = udf(flatMap)
    dataset.withColumn($(outputCol), flatMapUdf(col($(inputCol))))
  }

изменить transformSchema метод вывода ArrayType(StringType)
 override def transformSchema(schema: StructType): StructType = {
      val dataType = schema($(inputCol)).dataType
      require(
        dataType.isInstanceOf[StringType],
        s"Input column must be of type StringType but got ${dataType}")
      val inputFields = schema.fields
      require(
        !inputFields.exists(_.name == $(outputCol)),
        s"Output column ${$(outputCol)} already exists.")

      schema.add($(outputCol), ArrayType(StringType))
    }
изменить векторный ассемблер на этот -
val featureAssembler = new VectorAssembler()
      .setInputCols(Array("cat_features", "num_features", "cat_ohe_features"))
      .setOutputCol("features")

Я попытался выполнить ваш конвейер на фиктивном фреймворке данных, он работал хорошо. Полный код см. в этой сущности .

...