Spark collect_list и ограничить результирующий список - PullRequest
0 голосов
/ 23 сентября 2018

У меня есть фрейм данных следующего формата:

name          merged
key1    (internalKey1, value1)
key1    (internalKey2, value2)
...
key2    (internalKey3, value3)
...

Я хочу сгруппировать фрейм данных по name, собрать список и limit sizeсписок.

Вот как я группирую по name и собираю список:

val res = df.groupBy("name")
            .agg(collect_list(col("merged")).as("final"))

Соответствующий кадр данных выглядит примерно так:

 key1   [(internalKey1, value1), (internalKey2, value2),...] // Limit the size of this list 
 key2   [(internalKey3, value3),...]

ЧтоЯ хочу сделать, это ограничить размер создаваемых списков для каждого ключа.Я пробовал несколько способов сделать это, но безуспешно.Я уже видел некоторые посты, которые предлагают сторонние решения, но я хочу избежать этого.Есть ли способ?

Ответы [ 3 ]

0 голосов
/ 07 ноября 2018

Таким образом, пока UDF делает то, что вам нужно, если вы ищете более производительный способ, который также чувствителен к памяти, способ сделать это - написать UDAF.К сожалению, API UDAF на самом деле не так расширяемо, как агрегатные функции, которые поставляются с искрой.Однако вы можете использовать их внутренние API для создания внутренних функций и делать то, что вам нужно.

Вот реализация для collect_list_limit, которая в основном является копией внутренней CollectList AggregateFunction в Spark.Я бы просто расширил его, но это класс дела.На самом деле все, что нужно, это переопределить методы update и merge для соблюдения переданного предела:

case class CollectListLimit(
    child: Expression,
    limitExp: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {

  val limit = limitExp.eval( null ).asInstanceOf[Int]

  def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

  override def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
    if( buffer.size < limit ) super.update(buffer, input)
    else buffer
  }

  override def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
    if( buffer.size >= limit ) buffer
    else if( other.size >= limit ) other
    else ( buffer ++= other ).take( limit )
  }

  override def prettyName: String = "collect_list_limit"
}

И чтобы реально зарегистрировать его, мы можем сделать это через внутренний Spark FunctionRegistry, который принимает имя истроитель, который фактически является функцией, которая создает CollectListLimit, используя предоставленные выражения:

val collectListBuilder = (args: Seq[Expression]) => CollectListLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_list_limit", collectListBuilder )

Редактировать:

Оказывается, добавление его во встроенную функцию работает, только если вы не создалиSparkContext пока что делает неизменным клоном при запуске.Если у вас есть существующий контекст, то это должно работать, чтобы добавить его с отражением:

val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_list_limit", collectListBuilder )
0 голосов
/ 05 мая 2019

Yo может использовать UDF

Вот вероятный пример без необходимости схемы и со значительным сокращением

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._

import scala.collection.mutable


object TestJob1 {

  def main (args: Array[String]): Unit = {

val sparkSession = SparkSession
  .builder()
  .appName(this.getClass.getName.replace("$", ""))
  .master("local")
  .getOrCreate()

val sc = sparkSession.sparkContext

import sparkSession.sqlContext.implicits._

val rawDf = Seq(
  ("key", 1L, "gargamel"),
  ("key", 4L, "pe_gadol"),
  ("key", 2L, "zaam"),
  ("key1", 5L, "naval")
).toDF("group", "quality", "other")

rawDf.show(false)
rawDf.printSchema

val rawSchema = rawDf.schema

val fUdf = udf(reduceByQuality, rawSchema)

val aggDf = rawDf
  .groupBy("group")
  .agg(
    count(struct("*")).as("num_reads"),
    max(col("quality")).as("quality"),
    collect_list(struct("*")).as("horizontal")
  )
  .withColumn("short", fUdf($"horizontal"))
  .drop("horizontal")


aggDf.printSchema

aggDf.show(false)
}

def reduceByQuality= (x: Any) => {

val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]

val red = d.reduce((r1, r2) => {

  val quality1 = r1.getAs[Long]("quality")
  val quality2 = r2.getAs[Long]("quality")

  val r3 = quality1 match {
    case a if a >= quality2 =>
      r1
    case _ =>
      r2
  }

  r3
})

red
}
}

вот пример с данными, подобными вашим

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.functions._

import scala.collection.mutable


object TestJob {

  def main (args: Array[String]): Unit = {

val sparkSession = SparkSession
  .builder()
  .appName(this.getClass.getName.replace("$", ""))
  .master("local")
  .getOrCreate()

val sc = sparkSession.sparkContext

import sparkSession.sqlContext.implicits._


val df1 = Seq(
  ("key1", ("internalKey1", "value1")),
  ("key1", ("internalKey2", "value2")),
  ("key2", ("internalKey3", "value3")),
  ("key2", ("internalKey4", "value4")),
  ("key2", ("internalKey5", "value5"))
)
  .toDF("name", "merged")

//    df1.printSchema
//
//    df1.show(false)

val res = df1
  .groupBy("name")
  .agg( collect_list(col("merged")).as("final") )

res.printSchema

res.show(false)

def f= (x: Any) => {

  val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]

  val d1 = d.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]].head

  d1.toString
}

val fUdf = udf(f, StringType)

val d2 = res
  .withColumn("d", fUdf(col("final")))
  .drop("final")

d2.printSchema()

d2
  .show(false)
 }
 }
0 голосов
/ 23 сентября 2018

Вы можете создать функцию, которая ограничивает размер агрегированного столбца ArrayType, как показано ниже:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.Column

case class KV(k: String, v: String)

val df = Seq(
  ("key1", KV("internalKey1", "value1")),
  ("key1", KV("internalKey2", "value2")),
  ("key2", KV("internalKey3", "value3")),
  ("key2", KV("internalKey4", "value4")),
  ("key2", KV("internalKey5", "value5"))
).toDF("name", "merged")

def limitSize(n: Int, arrCol: Column): Column =
  array( (0 until n).map( arrCol.getItem ): _* )

df.
  groupBy("name").agg( collect_list(col("merged")).as("final") ).
  select( $"name", limitSize(2, $"final").as("final2") ).
  show(false)
// +----+----------------------------------------------+
// |name|final2                                        |
// +----+----------------------------------------------+
// |key1|[[internalKey1,value1], [internalKey2,value2]]|
// |key2|[[internalKey3,value3], [internalKey4,value4]]|
// +----+----------------------------------------------+
...