В приведенном ниже фрагменте вторая агрегация завершается неудачно (что неудивительно):
java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema нельзя привести к spark_test..Record
package spark_test
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession}
import org.scalatest.FunSuite
case class Record(k1: String, k2: String, v: Long) extends Serializable
class MyAggregator extends Aggregator[Record, Long, Long] {
override def zero: Long = 0
override def reduce(b: Long, a: Record): Long = a.v + b
override def merge(b1: Long, b2: Long): Long = b1 + b2
override def finish(reduction: Long): Long = reduction
override def bufferEncoder: Encoder[Long] = Encoders.scalaLong
override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
class TypeSafeAggTest extends FunSuite {
lazy val spark: SparkSession = {
SparkSession
.builder()
.master("local")
.appName("spark test")
.getOrCreate()
}
test("agg flow") {
import spark.sqlContext.implicits._
val df: DataFrame = Seq(
("a", "b", 1),
("a", "b", 1),
("c", "d", 1)
).toDF("k1", "k2", "v")
val aggregator = new MyAggregator()
.toColumn.name("output")
df.as[Record]
.groupByKey(_.k1)
.agg(aggregator)
.show(truncate = false) // < --- works #######
df.as[Record]
.groupBy($"k1", $"k2")
.agg(aggregator)
.show(truncate = false) // < --- fails runtime #######
}
}
Существует очень упрощенная примерная страница из официальных документов, но она не охватывает использование типов безопасных агрегаторов с группировкой (поэтому неясно, поддерживается ли такой случай).
http://spark.apachecn.org/docs/en/2.2.0/sql-programming-guide.html#type-safe-user-defined-aggregate-functions
Существует ли способ группировки по нескольким ключам при использовании агрегаторов с безопасным типом Spark?