У меня есть PySpark DataFrame, похожий на этот:
ID | value | period
a | 100 | 1
a | 100 | 1
b | 100 | 1
a | 100 | 2
b | 100 | 2
a | 100 | 3
Для каждого периода (1, 2, 3)
Я хочу отфильтровать данные, где период меньше или равен этому числу, а затем суммировать столбец значений для каждого идентификатора.
Так, например, период 1 даст (a:200, b:100)
, период 2 даст (a:300, b:200)
, а период 3 даст (a:400, b:200)
.
В данный момент я делаю это в цикле:
vals = [('a', 100, 1),
('a', 100, 1),
('b', 100, 1),
('a', 100, 2),
('b', 100, 2),
('a', 100, 3)]
cols = ['ID', 'value', 'period']
df = spark.createDataFrame(vals, cols)
for p in (1, 2, 3):
df_filter = df[df['period'] <= p]
results = df_filter.groupBy('ID').agg({'value':'sum'})
Затем я преобразовываю «результаты» в панды и добавляю их в один DataFrame.
Есть ли лучший способ сделать это без использования цикла? (на практике у меня есть сотни периодов).