Решение первой задачи, опубликованной автором вопроса (OP):
# Build the sample DataFrame: one (shouldMerge, number) pair per row
valuesCol = [
    (True, 1),
    (True, 1),
    (True, 2),
    (False, 3),
    (False, 1),
]
df = sqlContext.createDataFrame(valuesCol, schema=['shouldMerge', 'number'])
df.show()
+-----------+------+
|shouldMerge|number|
+-----------+------+
| true| 1|
| true| 1|
| true| 2|
| false| 3|
| false| 1|
+-----------+------+
# Packages to be imported
from pyspark.sql.window import Window
from pyspark.sql.functions import when, col, lag
# Expose the DataFrame to Spark SQL under the name 'table_view'
df.registerTempTable('table_view')
# Window aggregate: every row receives the total of `number` over all rows
# that share its shouldMerge value
df = sqlContext.sql(
    'select shouldMerge, number, sum(number) over (partition by shouldMerge) as sum_number from table_view'
)
# Rows flagged for merging take the group total; all other rows keep their
# own number
df = df.withColumn(
    'number',
    when(col('shouldMerge') == True, col('sum_number'))
    .otherwise(col('number')),
)
df.show()
+-----------+------+----------+
|shouldMerge|number|sum_number|
+-----------+------+----------+
| true| 4| 4|
| true| 4| 4|
| true| 4| 4|
| false| 3| 4|
| false| 1| 4|
+-----------+------+----------+
# The helper column is no longer needed: `number` already holds the totals
df = df.drop('sum_number')
# One unpartitioned window over the whole DataFrame, ordered by the flag
my_window = Window.partitionBy().orderBy('shouldMerge')
# For each row, fetch the flag of the row immediately before it in the
# window (null for the very first row)
df = df.withColumn(
    'shouldMerge_lag',
    lag(col('shouldMerge'), 1).over(my_window),
)
df.show()
+-----------+------+---------------+
|shouldMerge|number|shouldMerge_lag|
+-----------+------+---------------+
| false| 3| null|
| false| 1| false|
| true| 4| false|
| true| 4| true|
| true| 4| true|
+-----------+------+---------------+
# Keep every non-merge row plus exactly one row of the merged group: a true
# row is dropped only when the row before it is also true.
# The isNotNull() guard matters. For the first row of the window the lag is
# null, so `shouldMerge_lag == True` evaluates to null. Without the guard, a
# first row whose shouldMerge is true (i.e. input with no false rows at all)
# makes the whole predicate null, where() discards it, and the merged row is
# lost entirely. With the guard the conjunction is false for that row, the
# negation is true, and the row is kept.
df = df.where(
    ~(
        (col('shouldMerge') == True)
        & col('shouldMerge_lag').isNotNull()
        & (col('shouldMerge_lag') == True)
    )
).drop('shouldMerge_lag')
df.show()
+-----------+------+
|shouldMerge|number|
+-----------+------+
| false| 3|
| false| 1|
| true| 4|
+-----------+------+
Решение второй задачи, опубликованной автором вопроса (OP):
# Build the sample DataFrame: one (mergeId, number) pair per row
valuesCol = [
    (1, 2),
    (1, 1),
    (2, 1),
    (1, 2),
    (-1, 3),
    (-1, 1),
]
df = sqlContext.createDataFrame(valuesCol, schema=['mergeId', 'number'])
df.show()
+-------+------+
|mergeId|number|
+-------+------+
| 1| 2|
| 1| 1|
| 2| 1|
| 1| 2|
| -1| 3|
| -1| 1|
+-------+------+
# Packages to be imported
from pyspark.sql.window import Window
from pyspark.sql.functions import when, col, lag
# Expose the DataFrame to Spark SQL under the name 'table_view'
df.registerTempTable('table_view')
# Window aggregate: every row receives the total of `number` over all rows
# that share its mergeId
df = sqlContext.sql(
    'select mergeId, number, sum(number) over (partition by mergeId) as sum_number from table_view'
)
# Rows with a positive mergeId take the group total; all other rows keep
# their own number
df = df.withColumn(
    'number',
    when(col('mergeId') > 0, col('sum_number'))
    .otherwise(col('number')),
)
df.show()
+-------+------+----------+
|mergeId|number|sum_number|
+-------+------+----------+
| 1| 5| 5|
| 1| 5| 5|
| 1| 5| 5|
| 2| 1| 1|
| -1| 3| 4|
| -1| 1| 4|
+-------+------+----------+
# The helper column is no longer needed: `number` already holds the totals
df = df.drop('sum_number')
# One window per mergeId value
my_window = Window.partitionBy('mergeId').orderBy('mergeId')
# For each row, fetch the mergeId of the previous row inside its own
# partition; the first row of every partition gets null
df = df.withColumn(
    'mergeId_lag',
    lag(col('mergeId'), 1).over(my_window),
)
df.show()
+-------+------+-----------+
|mergeId|number|mergeId_lag|
+-------+------+-----------+
| 1| 5| null|
| 1| 5| 1|
| 1| 5| 1|
| 2| 1| null|
| -1| 3| null|
| -1| 1| -1|
+-------+------+-----------+
# Keep a row when it is not part of a merge group (mergeId not > 0) or when
# it is the first row of its partition (no previous row, so the lag is null).
# This is the De Morgan form of "not (mergeId > 0 and lag is not null)".
df = df.where(
    ~(col('mergeId') > 0) | col('mergeId_lag').isNull()
).drop('mergeId_lag')
df.show()
+-------+------+
|mergeId|number|
+-------+------+
| 1| 5|
| 2| 1|
| -1| 3|
| -1| 1|
+-------+------+
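Документация: lag() — оконная функция, которая возвращает значение из строки, расположенной на offset строк раньше текущей (или значение по умолчанию, null, если такой строки нет).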