Предположим, вы начинали со следующего DataFrame:
df.show()
#+----+----+------+
#|user|time|reward|
#+----+----+------+
#| bob| 3| 10|
#| bob| 4| 9|
#| bob| 5| 11|
#| jo| 4| 6|
#| jo| 5| 7|
#+----+----+------+
Вы можете присоединить этот DataFrame к себе в столбце user
и сохранить только те строки, где столбец time
правой таблицыбольше или равно столбцу времени левой таблицы.Мы упростили это, добавив псевдонимы DataFrames l
и r
.
. После объединения вы можете группировать по user
, time
и reward
из левой таблицы и агрегировать столбец вознаграждений.с правого стола.Однако, похоже, что groupBy
, за которым следует orderBy
, не гарантирует сохранение этого порядка , поэтому вы должны использовать Window
, чтобы быть явным.
from pyspark.sql import Window, functions as f
w = Window.partitionBy("user", "l.time", "l.reward").orderBy("r.time")
df = df.alias("l").join(df.alias("r"), on="user")\
.where("r.time>=l.time")\
.select(
"user",
f.col("l.time").alias("time"),
f.col("l.reward").alias("reward"),
f.collect_list("r.reward").over(w).alias("rewards")
)
df.show()
#+----+----+------+-----------+
#|user|time|reward| rewards|
#+----+----+------+-----------+
#| jo| 4| 6| [6]|
#| jo| 4| 6| [6, 7]|
#| jo| 5| 7| [7]|
#| bob| 3| 10| [10]|
#| bob| 3| 10| [10, 9]|
#| bob| 3| 10|[10, 9, 11]|
#| bob| 4| 9| [9]|
#| bob| 4| 9| [9, 11]|
#| bob| 5| 11| [11]|
#+----+----+------+-----------+
Теперьу вас есть все элементы, необходимые для вычисления столбца discounted_cum
.
Spark 2.1 и выше:
Вы можете использовать pyspark.sql.functions.posexplode
, чтобы взорвать массив rewards
вместе с индексом всписок.Это создаст новую строку для каждого значения в массиве rewards
.Используйте distinct
для удаления дубликатов, которые были введены с помощью функции Window
(вместо groupBy
).
Мы назовем индекс k
и награду rk
.Теперь вы можете применить свою функцию, используя pyspark.sql.functions.pow
gamma = 0.5
df.select("user", "time", "reward", f.posexplode("rewards").alias("k", "rk"))\
.distinct()\
.withColumn("discounted", f.pow(f.lit(gamma), f.col("k"))*f.col("rk"))\
.groupBy("user", "time")\
.agg(f.first("reward").alias("reward"), f.sum("discounted").alias("discounted_cum"))\
.show()
#+----+----+------+--------------+
#|user|time|reward|discounted_cum|
#+----+----+------+--------------+
#| bob| 3| 10| 17.25|
#| bob| 4| 9| 14.5|
#| bob| 5| 11| 11.0|
#| jo| 4| 6| 9.5|
#| jo| 5| 7| 7.0|
#+----+----+------+--------------+
Старые версии Spark
Для более старых версий spark вам придется использовать row_number()-1
, чтобы получить значения для k
после использования explode
:
df.select("user", "time", "reward", f.explode("rewards").alias("rk"))\
.distinct()\
.withColumn(
"k",
f.row_number().over(Window.partitionBy("user", "time").orderBy("time"))-1
)\
.withColumn("discounted", f.pow(f.lit(gamma), f.col("k"))*f.col("rk"))\
.groupBy("user", "time")\
.agg(f.first("reward").alias("reward"), f.sum("discounted").alias("discounted_cum"))\
.show()
#+----+----+------+--------------+
#|user|time|reward|discounted_cum|
#+----+----+------+--------------+
#| jo| 4| 6| 9.5|
#| jo| 5| 7| 7.0|
#| bob| 3| 10| 17.25|
#| bob| 4| 9| 14.5|
#| bob| 5| 11| 11.0|
#+----+----+------+--------------+