Я не могу дать прямой ответ. Но пока кто-то еще не ответит, вы можете использовать код ниже
import numpy as np
import pyspark.sql.functions as F
df = sqlContext.createDataFrame([(1, np.nan), (1, 0.12), (2, np.nan)], ('id', 'prob1'))
df = df.withColumn(
'prob1',
F.when(
F.col('prob1') == 0,
F.lit(0.01)
).otherwise(
F.col('prob1')
)
)
df = df.fillna(0)
df = df.groupBy('id').agg(
F.sum(F.col('prob1')).alias('label')
)
df = df.withColumn(
'label',
F.when(
F.col('label') != 0,
F.lit(1)
).otherwise(
F.col('label')
)
)
df.show()