Просто используйте Dataset
API.Здесь, в Scala, но Java будет почти идентична:
val rdd = sc.parallelize(Seq(
(1,200,"a"), (2,300,"a"), (1,300,"b"), (2,400,"a"), (2,500,"b"),
(3,200,"a"), (3,400,"b"), (1,500,"a"), (2,400,"b"), (3,500,"a"),
(1,200,"b")
))
val df = rdd.toDF("k1", "v", "k2")
df.groupBy("k1", "k2").mean("v").orderBy("k1", "k2").show
+---+---+------+
| k1| k2|avg(v)|
+---+---+------+
| 1| a| 350.0|
| 1| b| 250.0|
| 2| a| 350.0|
| 2| b| 450.0|
| 3| a| 350.0|
| 3| b| 400.0|
+---+---+------+
С картой RDD первым должен быть составной ключ:
rdd
.map(x => ((x._1, x._3), (x._2, 1.0)))
.reduceByKey((x, y) => (x._1 + y._1, x._2 + y._2))
.mapValues(x => x._1 / x._2)
.take(6).foreach(println)
((2,a),350.0)
((3,b),400.0)
((1,b),250.0)
((1,a),350.0)
((3,a),350.0)
((2,b),450.0)