как вычислить дисконтированную будущую совокупную сумму с помощью оконных функций spark pyspark или sql - PullRequest
0 голосов
/ 28 сентября 2018

Могу ли я рассчитать дисконтированную будущую совокупную сумму, используя spark sql?Ниже приведен пример, который вычисляет недисконтированную сумму будущей суммы, используя оконные функции, и я жестко закодировал то, что имею в виду под дисконтированной суммой суммы:

from pyspark.sql.window import Window


def undiscountedCummulativeFutureReward(df):
    windowSpec = Window \
        .partitionBy('user') \
        .orderBy('time') \
        .rangeBetween(0, Window.unboundedFollowing)

    tot_reward = F.sum('reward').over(windowSpec)

    df_tot_reward = df.withColumn('undiscounted', tot_reward)
    return df_tot_reward


def makeData(spark, gamma=0.5):
    data = [{'user': 'bob', 'time': 3, 'reward': 10, 'discounted_cum': 10 + (gamma * 9) + ((gamma ** 2) * 11)},
            {'user': 'bob', 'time': 4, 'reward': 9, 'discounted_cum': 9 + gamma * 11},
            {'user': 'bob', 'time': 5, 'reward': 11, 'discounted_cum': 11.0},
            {'user': 'jo', 'time': 4, 'reward': 6, 'discounted_cum': 6 + gamma * 7},
            {'user': 'jo', 'time': 5, 'reward': 7, 'discounted_cum': 7.0},
            ]
    schema = T.StructType([T.StructField('user', T.StringType(), False),
                           T.StructField('time', T.IntegerType(), False),
                           T.StructField('reward', T.IntegerType(), False),
                           T.StructField('discounted_cum', T.FloatType(), False)])

    return spark.createDataFrame(data=data, schema=schema)


def main(spark):
    df = makeData(spark)
    df = undiscountedCummulativeFutureReward(df)
    df.orderBy('user', 'time').show()
    return df

Когда вы запустите ее, вы получите:

+----+----+------+--------------+------------+
|user|time|reward|discounted_cum|undiscounted|
+----+----+------+--------------+------------+
| bob|   3|    10|         17.25|          30|
| bob|   4|     9|          14.5|          20|
| bob|   5|    11|          11.0|          11|
|  jo|   4|     6|           9.5|          13|
|  jo|   5|     7|           7.0|           7|
+----+----+------+--------------+------------+

Это дисконтированное значение sum \gamma^k r_k for k=0 to \infinity

Мне интересно, могу ли я вычислить дисконтированный столбец с помощью оконных функций, например, ввести столбец с рангом, литерал с гаммой, умножить все вместе - новсе еще не совсем ясно - я полагаю, я могу сделать это с какой-то UDF, но я думаю, что сначала мне нужно collect_as_list всех пользователей, вернуть новый список с суммой скидок и затем взорвать список.

1 Ответ

0 голосов
/ 28 сентября 2018

Предположим, вы начинали со следующего DataFrame:

df.show()
#+----+----+------+
#|user|time|reward|
#+----+----+------+
#| bob|   3|    10|
#| bob|   4|     9|
#| bob|   5|    11|
#|  jo|   4|     6|
#|  jo|   5|     7|
#+----+----+------+

Вы можете присоединить этот DataFrame к себе в столбце user и сохранить только те строки, где столбец time правой таблицыбольше или равно столбцу времени левой таблицы.Мы упростили это, добавив псевдонимы DataFrames l и r.

. После объединения вы можете группировать по user, time и reward из левой таблицы и агрегировать столбец вознаграждений.с правого стола.Однако, похоже, что groupBy, за которым следует orderBy, не гарантирует сохранение этого порядка , поэтому вы должны использовать Window, чтобы быть явным.

from pyspark.sql import Window, functions as f

w = Window.partitionBy("user", "l.time", "l.reward").orderBy("r.time")

df = df.alias("l").join(df.alias("r"), on="user")\
    .where("r.time>=l.time")\
    .select(
        "user",
        f.col("l.time").alias("time"),
        f.col("l.reward").alias("reward"),
        f.collect_list("r.reward").over(w).alias("rewards")
    )

df.show()
#+----+----+------+-----------+
#|user|time|reward|    rewards|
#+----+----+------+-----------+
#|  jo|   4|     6|        [6]|
#|  jo|   4|     6|     [6, 7]|
#|  jo|   5|     7|        [7]|
#| bob|   3|    10|       [10]|
#| bob|   3|    10|    [10, 9]|
#| bob|   3|    10|[10, 9, 11]|
#| bob|   4|     9|        [9]|
#| bob|   4|     9|    [9, 11]|
#| bob|   5|    11|       [11]|
#+----+----+------+-----------+

Теперьу вас есть все элементы, необходимые для вычисления столбца discounted_cum.

Spark 2.1 и выше:

Вы можете использовать pyspark.sql.functions.posexplode, чтобы взорвать массив rewards вместе с индексом всписок.Это создаст новую строку для каждого значения в массиве rewards.Используйте distinct для удаления дубликатов, которые были введены с помощью функции Window (вместо groupBy).

Мы назовем индекс k и награду rk.Теперь вы можете применить свою функцию, используя pyspark.sql.functions.pow

gamma = 0.5

df.select("user", "time", "reward", f.posexplode("rewards").alias("k", "rk"))\
    .distinct()\
    .withColumn("discounted", f.pow(f.lit(gamma), f.col("k"))*f.col("rk"))\
    .groupBy("user", "time")\
    .agg(f.first("reward").alias("reward"), f.sum("discounted").alias("discounted_cum"))\
    .show()
#+----+----+------+--------------+
#|user|time|reward|discounted_cum|
#+----+----+------+--------------+
#| bob|   3|    10|         17.25|
#| bob|   4|     9|          14.5|
#| bob|   5|    11|          11.0|
#|  jo|   4|     6|           9.5|
#|  jo|   5|     7|           7.0|
#+----+----+------+--------------+

Старые версии Spark

Для более старых версий spark вам придется использовать row_number()-1, чтобы получить значения для k после использования explode:

df.select("user", "time", "reward", f.explode("rewards").alias("rk"))\
    .distinct()\
    .withColumn(
        "k",
        f.row_number().over(Window.partitionBy("user", "time").orderBy("time"))-1
    )\
    .withColumn("discounted", f.pow(f.lit(gamma), f.col("k"))*f.col("rk"))\
    .groupBy("user", "time")\
    .agg(f.first("reward").alias("reward"), f.sum("discounted").alias("discounted_cum"))\
    .show()
#+----+----+------+--------------+
#|user|time|reward|discounted_cum|
#+----+----+------+--------------+
#|  jo|   4|     6|           9.5|
#|  jo|   5|     7|           7.0|
#| bob|   3|    10|         17.25|
#| bob|   4|     9|          14.5|
#| bob|   5|    11|          11.0|
#+----+----+------+--------------+
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...