Лучший способ сделать это - преобразовать вложенные массивы в их собственные строки, чтобы вы могли использовать один groupBy
.Таким образом, вы можете сделать все это в одной агрегации вместо 100 (или более).Ключом к этому является использование posexplode
, которое превратит каждую запись в массиве в новую строку с индексом, в котором она была расположена в массиве.
Например:
import org.apache.spark.sql.functions.{posexplode, collect_list}
val data = Seq(
(Seq(1, 2, 3, 4, 5)),
(Seq(2, 3, 4, 5, 6)),
(Seq(3, 4, 5, 6, 7))
)
val df = data.toDF
val df2 = df.
select(posexplode($"value")).
groupBy($"pos").
agg(sum($"col") as "sum")
// At this point you will have rows with the index and the sum
df2.orderBy($"pos".asc).show
Выводит DataFrame следующим образом:
+---+---+
|pos|sum|
+---+---+
| 0| 6|
| 1| 9|
| 2| 12|
| 3| 15|
| 4| 18|
+---+---+
Или, если вы хотите, чтобы они были в одной строке, вы могли бы объявить что-то вроде этого:
df2.groupBy().agg(collect_list(struct($"pos", $"sum")) as "list").show
Значения в столбце Array не будутне может быть отсортировано, но вы можете написать UDF для сортировки по полю pos и удалить поле pos, если хотите это сделать.
Обновлено за комментарий
Если описанный выше подход не работает с какими-либо другими агрегатами, которые вы пытаетесь сделать, вам нужно будет определить свой собственный UDAF.Общая идея здесь заключается в том, что вы говорите Spark, как объединять значения для одного и того же ключа внутри раздела для создания промежуточных значений, а затем как объединять эти промежуточные значения между разделами для создания окончательного значения для каждого ключа.Определив класс UDAF, вы можете использовать его в вызове aggs
с любыми другими агрегатами, которые вы хотели бы сделать.
Вот быстрый пример, который я выбил.Обратите внимание, что он принимает длину массива, и, вероятно, его следует сделать более защищенным от ошибок, но он поможет вам в этом.
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
class ArrayCombine extends UserDefinedAggregateFunction {
// The input this aggregation will receive (each row)
override def inputSchema: org.apache.spark.sql.types.StructType =
StructType(StructField("value", ArrayType(IntegerType)) :: Nil)
// Your intermediate state as you are updating with data from each row
override def bufferSchema: StructType = StructType(
StructType(StructField("value", ArrayType(IntegerType)) :: Nil)
)
// This is the output type of your aggregatation function.
override def dataType: DataType = ArrayType(IntegerType)
override def deterministic: Boolean = true
// This is the initial value for your buffer schema.
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = (0 until 100).toArray
}
// Given a new input row, update our state
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val sums = buffer.getSeq[Int](0)
val newVals = input.getSeq[Int](0)
buffer(0) = sums.zip(newVals).map { case (a, b) => a + b }
}
// After we have finished computing intermediate values for each partition, combine the partitions
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
val sums1 = buffer1.getSeq[Int](0)
val sums2 = buffer2.getSeq[Int](0)
buffer1(0) = sums1.zip(sums2).map { case (a, b) => a + b }
}
// This is where you output the final value, given the final value of your bufferSchema.
override def evaluate(buffer: Row): Any = {
buffer.getSeq[Int](0)
}
}
Затем вызовите его так:
val arrayUdaf = new ArrayCombine()
df.groupBy().agg(arrayUdaf($"value")).show