Как рассчитать среднее значение в одной группе? - PullRequest
0 голосов
/ 18 января 2019

У меня есть такой фрейм данных:

+-----+---------+---------+
|Categ|      Amt|    price|
+-----+---------+---------+
|    A|      100|        1|
|    A|      180|        2|
|    A|      250|        3|
|    B|       90|        2|
|    B|      170|        3|
|    B|      280|        3|
+-----+---------+---------+

Я хочу сгруппировать по категории, чтобы вычислить среднюю цену в перекрывающихся диапазонах. Скажем, эти диапазоны [0-200] и [150-300]. Итак, вывод, который я хотел бы получить, выглядит следующим образом:

+-----+---------+---------+
|Categ|rang(Amt)| mean(price)|
+-----+---------+---------+
|    A|  [0-200]|      1.5|
|    A|[150-300]|      2.5|
|    B|  [0-200]|      2.5|
|    B|[150-300]|        3|
+-----+---------+---------+

Ответы [ 2 ]

0 голосов
/ 18 января 2019

Проверьте это.

scala> val df = Seq(("A",100,1),("A",180,2),("A",250,3),("B",90,2),("B",170,3),("B",280,3)).toDF("categ","amt","price")
df: org.apache.spark.sql.DataFrame = [categ: string, amt: int ... 1 more field]

scala> df.show(false)
+-----+---+-----+
|categ|amt|price|
+-----+---+-----+
|A    |100|1    |
|A    |180|2    |
|A    |250|3    |
|B    |90 |2    |
|B    |170|3    |
|B    |280|3    |
+-----+---+-----+

scala> val df2 = df.withColumn("newc",array(when('amt>=0 and 'amt <=200, map(lit("[0-200]"),'price)),when('amt>150 and 'amt<=300, map(lit("[150-3
00]"),'price))))
df2: org.apache.spark.sql.DataFrame = [categ: string, amt: int ... 2 more fields]

scala> val df3 = df2.select(col("*"), explode('newc).as("rangekv")).select(col("*"),explode('rangekv).as(Seq("range","price2")))
df3: org.apache.spark.sql.DataFrame = [categ: string, amt: int ... 5 more fields]

scala> df3.show(false)
+-----+---+-----+----------------------------------+----------------+---------+------+
|categ|amt|price|newc                              |rangekv         |range    |price2|
+-----+---+-----+----------------------------------+----------------+---------+------+
|A    |100|1    |[[[0-200] -> 1],]                 |[[0-200] -> 1]  |[0-200]  |1     |
|A    |180|2    |[[[0-200] -> 2], [[150-300] -> 2]]|[[0-200] -> 2]  |[0-200]  |2     |
|A    |180|2    |[[[0-200] -> 2], [[150-300] -> 2]]|[[150-300] -> 2]|[150-300]|2     |
|A    |250|3    |[, [[150-300] -> 3]]              |[[150-300] -> 3]|[150-300]|3     |
|B    |90 |2    |[[[0-200] -> 2],]                 |[[0-200] -> 2]  |[0-200]  |2     |
|B    |170|3    |[[[0-200] -> 3], [[150-300] -> 3]]|[[0-200] -> 3]  |[0-200]  |3     |
|B    |170|3    |[[[0-200] -> 3], [[150-300] -> 3]]|[[150-300] -> 3]|[150-300]|3     |
|B    |280|3    |[, [[150-300] -> 3]]              |[[150-300] -> 3]|[150-300]|3     |
+-----+---+-----+----------------------------------+----------------+---------+------+

scala> df3.groupBy('categ,'range).agg(avg('price)).orderBy('categ).show(false)
+-----+---------+----------+
|categ|range    |avg(price)|
+-----+---------+----------+
|A    |[0-200]  |1.5       |
|A    |[150-300]|2.5       |
|B    |[0-200]  |2.5       |
|B    |[150-300]|3.0       |
+-----+---------+----------+

scala>   

Вы также можете создать массив из range строк и взорвать их. Но в этом случае вы получите NULL после взрыва, поэтому вам нужно отфильтровать их.

scala> val df2 = df.withColumn("newc",array(when('amt>=0 and 'amt <=200, lit("[0-200]")),when('amt>150 and 'amt<=300,lit("[150-300]") )))
df2: org.apache.spark.sql.DataFrame = [categ: string, amt: int ... 2 more fields]

scala> val df3 = df2.select(col("*"), explode('newc).as("range"))
df3: org.apache.spark.sql.DataFrame = [categ: string, amt: int ... 3 more fields]

scala> df3.groupBy('categ,'range).agg(avg('price)).orderBy('categ).show(false)
+-----+---------+----------+
|categ|range    |avg(price)|
+-----+---------+----------+
|A    |[150-300]|2.5       |
|A    |[0-200]  |1.5       |
|A    |null     |2.0       |
|B    |[0-200]  |2.5       |
|B    |null     |2.5       |
|B    |[150-300]|3.0       |
+-----+---------+----------+

scala> df3.groupBy('categ,'range).agg(avg('price)).filter(" range is not null ").orderBy('categ).show(false)
+-----+---------+----------+
|categ|range    |avg(price)|
+-----+---------+----------+
|A    |[150-300]|2.5       |
|A    |[0-200]  |1.5       |
|B    |[0-200]  |2.5       |
|B    |[150-300]|3.0       |
+-----+---------+----------+


scala>
0 голосов
/ 18 января 2019

Вы можете отфильтровать свои значения перед группировкой, добавить столбец с именем диапазона и затем объединить результаты.

agg_range_0_200 = df
.filter('Amt > 0 and Amt < 200')
.groupBy('Categ').agg(mean('price'))
.withColumn('rang(Amt)', '[0-200]')

agg_range_150_300 = df
.filter('Amt > 150 and Amt < 300')
.groupBy('Categ').agg(mean('price'))
.withColumn('rang(Amt)', '[150-300]')

agg_range = agg_range_0_200.union(agg_range_150_300)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...