Для Spark v 2.1 +
Вы можете воспользоваться pyspark.sql.functions.posexplode()
, чтобы разбить столбец вместе с индексом, который появляется в вашем массиве, а затем разделить результирующую позицию на n
, чтобы создать группы.
Например, вот результат использования posexplode()
в вашем DataFrame:
import pyspark.sql.functions as f
df.select('ID', f.posexplode('Example')).show()
#+---+---+------+
#| ID|pos| col|
#+---+---+------+
#| A| 0|[1, 2]|
#| A| 1|[3, 4]|
#| A| 2|[5, 6]|
#+---+---+------+
Обратите внимание, что мы получаем два столбца: pos
и col
вместо одного. Поскольку нам нужны группы n
, мы можем просто разделить pos
на n
и взять floor
, чтобы получить группы.
n = 2
df.select('ID', f.posexplode('Example'))\
.withColumn("group", f.floor(f.col("pos")/n))\
.show(truncate=False)
#+---+---+------+-----+
#|ID |pos|col |group|
#+---+---+------+-----+
#|A |0 |[1, 2]|0 |
#|A |1 |[3, 4]|0 |
#|A |2 |[5, 6]|1 |
#+---+---+------+-----+
Теперь сгруппируйте по "ID"
и "group"
и используйте pyspark.sql.functions.collect_list()
, чтобы получить желаемый результат.
df.select('ID', f.posexplode('Example'))\
.withColumn("group", f.floor(f.col("pos")/n))\
.groupBy("ID", "group")\
.agg(f.collect_list("col").alias("Example"))\
.sort("group")\
.drop("group")\
.show(truncate=False)
#+---+----------------------------------------+
#|ID |Example |
#+---+----------------------------------------+
#|A |[WrappedArray(1, 2), WrappedArray(3, 4)]|
#|A |[WrappedArray(5, 6)] |
#+---+----------------------------------------+
Вы увидите, что я также отсортировал по столбцу "group"
и удалил его, но это необязательно в зависимости от ваших потребностей.
Для более старых версий Spark
Есть несколько других методов для версий Spark ниже 2.1. Все эти методы выдают тот же результат, что и выше.
1. Использование udf
Вы можете использовать udf
, чтобы разбить ваш массив на группы. Например:
def get_groups(array, n):
return filter(lambda x: x, [array[i*n:(i+1)*n] for i in range(len(array))])
get_groups_of_2 = f.udf(
lambda x: get_groups(x, 2),
ArrayType(ArrayType(ArrayType(IntegerType())))
)
df.select("ID", f.explode(get_groups_of_2("Example")).alias("Example"))\
.show(truncate=False)
Функция get_groups()
примет массив и вернет массив групп из n элементов.
2. Использование rdd
Другим вариантом является сериализация до rdd
и использование функции get_groups()
внутри вызова для map()
. Затем преобразуйте обратно в DataFrame. Вам нужно будет указать схему для этого преобразования для правильной работы.
n = 2
schema = StructType(
[
StructField("ID", StringType()),
StructField("Example", ArrayType(ArrayType(ArrayType(IntegerType()))))
]
)
df.rdd.map(lambda x: (x["ID"], get_groups(x["Example"], n=n)))\
.toDF(schema)\
.select("ID", f.explode("Example").alias("Example"))\
.show(truncate=False)