В Spark, как сделать One Hot Encoding только для первых N часто используемых значений? - PullRequest
0 голосов
/ 15 февраля 2020

Пусть в моем фрейме данных df у меня есть столбец my_category, в котором у меня есть разные значения, и я могу просмотреть счетчики значений, используя:

df.groupBy("my_category").count().show()

value   count
a    197
b    166
c    210
d      5
e      2
f      9
g      3

Теперь я хотел бы применить Один Hot Encoding (OHE) для этого столбца, но только для верхних N частых значений (скажем, N = 3), и поместите все остальные нечастые значения в фиктивный столбец (скажем, «по умолчанию»). Например, вывод должен выглядеть примерно так:

a  b  c  default
0  0  1  0
1  0  0  0
0  1  0  0
1  0  0  0
...
0  0  0  1
0  0  0  1
...

Как мне это сделать в Spark / Scala?

Примечание: я знаю, как это сделать в Python, например, сначала создав словарь на основе частоты для каждого уникального значения, а затем создайте вектор OHE, проверяя значения одно за другим, помещая нечастые в столбец «по умолчанию».

1 Ответ

1 голос
/ 15 февраля 2020

Пользовательская функция может быть написана так, чтобы применять One Hot Encoding (OHE) к определенному столбцу только для верхних N частых значений (скажем, N = 3).

Это относительно похоже на Python, 1) Построение верхнего n часто встречающегося Dataframe / Dictionary. 2) Поверните верхний n частый кадр данных, т.е. создайте вектор OHE. 3) Слева соединить данный Dataframe и развернуть Dataframe, заменить ноль на 0, то есть вектор OHE по умолчанию.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, lit, when}
import org.apache.spark.sql.Column

import spark.implicits._
val df = spark
  .sparkContext
  .parallelize(Seq("a", "b", "c", "a", "b", "c", "d", "e", "a", "b", "f", "a", "g", "a", "b", "c", "a", "d", "e", "f", "a", "b", "g", "b", "c", "f", "a", "b", "c"))
  .toDF("value")

val oheEncodedDF = oheEncoding(df, "value", 3)


def oheEncoding(df: DataFrame, colName: String, n: Int): DataFrame = {
  df.createOrReplaceTempView("data")
  val topNDF = spark.sql(s"select $colName, count(*) as count from data group by $colName order by count desc limit $n")

  val pivotTopNDF = topNDF
    .groupBy(colName)
    .pivot(colName)
    .count()
    .withColumn("default", lit(1))

  val joinedTopNDF = df.join(pivotTopNDF, Seq(colName), "left").drop(colName)

  val oheEncodedDF = joinedTopNDF
    .na.fill(0, joinedTopNDF.columns)
    .withColumn("default", flip(col("default")))

   oheEncodedDF
}

def flip(col: Column): Column = when(col === 1, lit(0)).otherwise(lit(1))
...