Таким образом, пока UDF делает то, что вам нужно, если вы ищете более производительный способ, который также чувствителен к памяти, способ сделать это - написать UDAF.К сожалению, API UDAF на самом деле не так расширяемо, как агрегатные функции, которые поставляются с искрой.Однако вы можете использовать их внутренние API для создания внутренних функций и делать то, что вам нужно.
Вот реализация для collect_list_limit
, которая в основном является копией внутренней CollectList
AggregateFunction в Spark.Я бы просто расширил его, но это класс дела.На самом деле все, что нужно, это переопределить методы update и merge для соблюдения переданного предела:
case class CollectListLimit(
child: Expression,
limitExp: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {
val limit = limitExp.eval( null ).asInstanceOf[Int]
def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
override def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
if( buffer.size < limit ) super.update(buffer, input)
else buffer
}
override def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
if( buffer.size >= limit ) buffer
else if( other.size >= limit ) other
else ( buffer ++= other ).take( limit )
}
override def prettyName: String = "collect_list_limit"
}
И чтобы реально зарегистрировать его, мы можем сделать это через внутренний Spark FunctionRegistry
, который принимает имя истроитель, который фактически является функцией, которая создает CollectListLimit
, используя предоставленные выражения:
val collectListBuilder = (args: Seq[Expression]) => CollectListLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_list_limit", collectListBuilder )
Редактировать:
Оказывается, добавление его во встроенную функцию работает, только если вы не создалиSparkContext пока что делает неизменным клоном при запуске.Если у вас есть существующий контекст, то это должно работать, чтобы добавить его с отражением:
val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_list_limit", collectListBuilder )