Я хочу реализовать следующую формулу, используя pyspark:
Lx_BOP(1) = 1
Lx_BOP(n+1) = Lx_BOP(n) * (1 - rate(n))
Я создал эти тестовые данные:
termination_rate_input = [
["dummy_rate_flag", 1, 0.1],
["dummy_rate_flag", 2, 0.1],
["dummy_rate_flag", 3, 0.1],
["dummy_rate_flag", 4, 0.1],
["dummy_rate_flag", 5, 0.1],
]
input_schuma = StructType([
StructField("rate_flag", StringType(), True),
StructField("months_since_event", IntegerType(), True),
StructField("bop_monthly_scaled_rate", DoubleType(), True)
])
и эта логика:
def add_lx(rate_df):
df = rate_df
lx_window = W.partitionBy("rate_flag").orderBy(F.col("months_since_event"))
# Add bop_monthly_scaled_rate(n) to n+1 row
df = df.withColumn(
"_n_bop_monthly_scaled_rate",
(F.lit(1) - F.lag(F.col("bop_monthly_scaled_rate"), offset=1, default=1).over(lx_window)))
df = df.withColumn(
"_n_bop_monthly_scaled_rate",
F.when(F.col("months_since_event") == F.lit(1), F.lit(1)).otherwise(F.col("_n_bop_monthly_scaled_rate")))
# compute lx_bop based on _n_bop_lx and the bop_monthly_scaled_rate(n)
df = df.withColumn(
"lx_bop",
F.exp(F.sum(F.log(F.lag(F.col("_n_bop_monthly_scaled_rate"), offset=1, default=1.0))).over(lx_window)))
return df
эта часть F.exp(F.sum(F.log(F.lag(F.col("_n_bop_monthly_scaled_rate"), offset=1, default=1.0))).over(lx_window)))
используется для умножения каждой ячейки в строке (n) на ячейку в строке (n-1).
, но яполучаю исключение: java.util.concurrent.ExecutionException: java.lang.UnsupportedOperationException: Cannot generate code for expression: lag(CASE WHEN (input[1, int, true] = 1) THEN 1.0 ELSE (1.0 - input[3, double, true]) END, 1, 1.0)
Есть ли другой способ сделать это?