Здесь проще всего использовать pyspark.sql.functions.collect_set
во всех столбцах:
import pyspark.sql.functions as f
df.select(*[f.collect_set(c).alias(c) for c in df.columns]).show()
#+------+-----+---------+
#| COL_1|COL_2| COL_3|
#+------+-----+---------+
#|[B, A]| [C]|[F, E, D]|
#+------+-----+---------+
Очевидно, это возвращает данные в виде одной строки.
Если вместо этого вы хотите получить результат, который вы написали в своем вопросе (по одной строке на уникальное значение для каждого столбца), это выполнимо, но требует немало гимнастики pyspark (и любое решение, вероятно, будет гораздо менее эффективным).
Тем не менее, я представляю вам несколько вариантов:
Вариант 1: взорваться и присоединиться
Вы можете использовать pyspark.sql.functions.posexplode
, чтобы разбить элементы в наборе значений для каждого столбца вместе с индексом в массиве. Сделайте это для каждого столбца отдельно, а затем внешне объедините результирующий список DataFrames вместе, используя functools.reduce
:
from functools import reduce
unique_row = df.select(*[f.collect_set(c).alias(c) for c in df.columns])
final_df = reduce(
lambda a, b: a.join(b, how="outer", on="pos"),
(unique_row.select(f.posexplode(c).alias("pos", c)) for c in unique_row.columns)
).drop("pos")
final_df.show()
#+-----+-----+-----+
#|COL_1|COL_2|COL_3|
#+-----+-----+-----+
#| A| null| E|
#| null| null| D|
#| B| C| F|
#+-----+-----+-----+
Вариант 2: выбор по позиции
Сначала вычислите размер максимального массива и сохраните его в новом столбце max_length
. Затем выберите элементы из каждого массива, если значение существует по этому индексу.
Еще раз мы используем pyspark.sql.functions.posexplode
, но на этот раз мы просто создаем столбец для представления индекса в каждом массиве для извлечения.
Наконец, мы используем этот трюк , который позволяет использовать значение столбца в качестве параметра.
final_df= df.select(*[f.collect_set(c).alias(c) for c in df.columns])\
.withColumn("max_length", f.greatest(*[f.size(c) for c in df.columns]))\
.select("*", f.expr("posexplode(split(repeat(',', max_length-1), ','))"))\
.select(
*[
f.expr(
"case when size({c}) > pos then {c}[pos] else null end AS {c}".format(c=c))
for c in df.columns
]
)
final_df.show()
#+-----+-----+-----+
#|COL_1|COL_2|COL_3|
#+-----+-----+-----+
#| B| C| F|
#| A| null| E|
#| null| null| D|
#+-----+-----+-----+