Ниже приведён код, который вы ищете:
// For every COL1 value: number of distinct COL2 values, number of distinct
// COL3 values, and the total number of rows in the group.
val byCol1 = df.groupBy("COL1")
val summary = byCol1.agg(
  countDistinct("COL2"),
  countDistinct("COL3"),
  count($"*")
)
summary.show()
======= Проверка: сеанс spark-shell ниже =======
scala> val lst = List(("a","x","d"),("b","D","s"),("ss","kk","ll"),("a","y","e"),("b","c","y"),("a","x","y"));
lst: List[(String, String, String)] = List((a,x,d), (b,D,s), (ss,kk,ll), (a,y,e), (b,c,y), (a,x,y))
scala> val rdd=sc.makeRDD(lst);
rdd: org.apache.spark.rdd.RDD[(String, String, String)] = ParallelCollectionRDD[7] at makeRDD at <console>:26
scala> val df = rdd.toDF("COL1","COL2","COL3");
df: org.apache.spark.sql.DataFrame = [COL1: string, COL2: string ... 1 more field]
scala> df.printSchema
root
|-- COL1: string (nullable = true)
|-- COL2: string (nullable = true)
|-- COL3: string (nullable = true)
scala> df.groupBy("COL1").agg(countDistinct("COL2"),countDistinct("COL3"),count($"*")).show
+----+--------------------+--------------------+--------+
|COL1|count(DISTINCT COL2)|count(DISTINCT COL3)|count(1)|
+----+--------------------+--------------------+--------+
|  ss|                   1|                   1|       1|
|   b|                   2|                   2|       2|
|   a|                   2|                   3|       3|
+----+--------------------+--------------------+--------+
scala>