Редактировать: Как обсуждалось в комментариях, проблема для оригинальных методов могла быть из count
с использованием функций фильтрации или агрегирования, которые запускают ненужные операции сканирования данных. Ниже мы разбиваем массивы и выполняем агрегирование (подсчет) перед созданием окончательного столбца массива:
from pyspark.sql.functions import collect_list, struct
df = spark.createDataFrame([(2,[1,2]), (2,[1,2]), (3,[1,2,3]), (3,[1,2])],['timestamp', 'vars'])
df.selectExpr("timestamp", "explode(vars) as var") \
.groupby('timestamp','var') \
.count() \
.groupby("timestamp") \
.agg(collect_list(struct("var","count")).alias("data")) \
.selectExpr(
"timestamp",
"transform(data, x -> x.var) as indices",
"transform(data, x -> x.count) as values"
).selectExpr(
"timestamp",
"transform(sequence(0, array_max(indices)), i -> IFNULL(values[array_position(indices,i)-1],0)) as new_vars"
).show(truncate=False)
+---------+------------+
|timestamp|new_vars |
+---------+------------+
|3 |[0, 2, 2, 1]|
|2 |[0, 2, 2] |
+---------+------------+
Где:
(1) мы разбиваем массив и сделать count () для каждого timestamp
+ var
(2) groupby timestamp
и создать массив структур, содержащий два поля var
и count
(3 ) преобразовать массив структур в два массива: индексы и значения (аналогично тому, что мы определяем SparseVector)
(4) преобразовать последовательность sequence(0, array_max(indices))
, для каждого i в последовательности используйте array_position чтобы найти индекс i
в массиве indices
и затем извлечь значение из массива values
в той же позиции, см. Ниже:
IFNULL(values[array_position(indices,i)-1],0)
уведомление , что функция array_position использует индекс на основе 1, а индексация массива - на основе 0, поэтому в приведенном выше выражении мы имеем -1
.
Старые методы:
(1) Использовать transform + filter / size
from pyspark.sql.functions import flatten, collect_list
df.groupby('timestamp').agg(flatten(collect_list('vars')).alias('data')) \
.selectExpr(
"timestamp",
"transform(sequence(0, array_max(data)), x -> size(filter(data, y -> y = x))) as vars"
).show(truncate=False)
+---------+------------+
|timestamp|vars |
+---------+------------+
|3 |[0, 2, 2, 1]|
|2 |[0, 2, 2] |
+---------+------------+
(2) Использование агрегат функция:
df.groupby('timestamp').agg(flatten(collect_list('vars')).alias('data')) \
.selectExpr("timestamp", """
aggregate(
data,
/* use an array as zero_value, size = array_max(data))+1 and all values are zero */
array_repeat(0, int(array_max(data))+1),
/* increment the ith value of the Array by 1 if i == y */
(acc, y) -> transform(acc, (x,i) -> IF(i=y, x+1, x))
) as vars
""").show(truncate=False)