Накопительный продукт UDF для Spark SQL - PullRequest
1 голос
/ 09 апреля 2020

Я видел в других постах, что это делается для фреймов данных: { ссылка }

Но я пытаюсь выяснить, как я могу написать udf для кумулятивного продукта.

Если у меня есть очень базовый c стол

Input data:
+----+
| val|
+----+
| 1  |
| 2  |
| 3  |
+----+

Если я хочу взять сумму, я могу просто сделать что-то вроде

sparkSession.createOrReplaceTempView("table")
spark.sql("""Select SUM(table.val) from table""").show(100, false)

и это просто работает, потому что SUM - предопределенная функция.

Как бы я определил что-то похожее для умножения (или даже как я могу реализовать сумму в UDF сам)?

Попробовать следующее

sparkSession.createOrReplaceTempView("_Period0")

val prod = udf((vals:Seq[Decimal]) => vals.reduce(_ * _))
spark.udf.register("prod",prod)

spark.sql("""Select prod(table.vals) from table""").show(100, false)

Я получаю следующую ошибку:

Message: cannot resolve 'UDF(vals)' due to data type mismatch: argument 1 requires array<decimal(38,18)> type, however, 'table.vals' is of decimal(28,14)

Очевидно, что каждая указанная c ячейка не является массивом, но, похоже, udf должен принимать массив для выполнения агрегации. Возможно ли это даже при искре sql?

1 Ответ

1 голос
/ 09 апреля 2020

Вы можете реализовать это через UserDefinedAggregateFunction Вам необходимо определить несколько функций для работы со входными данными и значениями буфера.

Быстрый пример для функции продукта с использованием просто удваивается как тип:

  import org.apache.spark.sql.expressions.MutableAggregationBuffer
  import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
  import org.apache.spark.sql.Row
  import org.apache.spark.sql.types._


    class myUDAF extends UserDefinedAggregateFunction {

      // inputSchema for the function
      override def inputSchema: StructType = {
        new StructType().add("val", DoubleType, nullable = true)
      }

     //Schema for the inner UDAF buffer, in the product case, you just need an accumulator
     override def bufferSchema: StructType = StructType(StructField("accumulated", DoubleType) :: Nil)

    //OutputDataType
    override def dataType: DataType = DoubleType

    override def deterministic: Boolean = true

    //Initicla buffer value 1 for product
    override def initialize(buffer: MutableAggregationBuffer) = buffer(0) = 1.0

    //How to update the buffer, for product you just need to perform a product between the two elements (buffer & input)
    override def update(buffer: MutableAggregationBuffer, input: Row) = {
        buffer(0) = buffer.getAs[Double](0) * input.getAs[Double](0)
      }

      //Merge results with the previous buffered value (product as well here)
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getAs[Double](0) * buffer2.getAs[Double](0)
      }

      //Function on how to return the value
      override def evaluate(buffer: Row) = buffer.getAs[Double](0)

    }

Тогда вы можете зарегистрировать функцию , как вы сделали бы с любым другим UDF:

spark.udf.register("prod", new myUDAF)

РЕЗУЛЬТАТ

scala> spark.sql("Select prod(val) from table").show
+-----------+
|myudaf(val)|
+-----------+
|        6.0|
+-----------+

Дополнительную документацию можно найти здесь

...