AFAIK, нет способа динамически итерировать по ArrayType()
, поэтому, если ваши данные уже находятся в массиве, у вас есть два варианта:
Вариант 1: взорвать, отфильтровать, собрать
Используйте pyspark.sql.functions.explode()
, чтобы превратить элементы массива в отдельные строки. Затем используйте pyspark.sql.DataFrame.where()
, чтобы отфильтровать нужные значения. Наконец, выполните groupBy()
и collect_set()
, чтобы собрать данные обратно в один ряд.
df_grouped = df.groupby("id").agg(F.collect_set("code").alias("codes"))
df_grouped.select("*", F.explode("codes").alias("exploded"))\
.where(~F.col("exploded").isin(["code2"]))\
.groupBy("id")\
.agg(F.collect_set("exploded").alias("codes"))\
.show()
#+---+-------+
#| id| codes|
#+---+-------+
#| a|[code1]|
#+---+-------+
Вариант 2: Используйте UDF
def filter_code(array):
bad_values={"code2"}
return [x for x in array if x not in bad_values]
filter_code_udf = F.udf(lambda x: filter_code(x), ArrayType(StringType()))
df_grouped = df.groupby("id").agg(F.collect_set("code").alias("codes"))
df_grouped.withColumn("codes_filtered", filter_code_udf("codes")).show()
#+---+--------------+--------------+
#| id| codes|codes_filtered|
#+---+--------------+--------------+
#| a|[code2, code1]| [code1]|
#+---+--------------+--------------+
Конечно, если вы начинаете с исходного кадра данных (до groupBy()
и collect_set()
), вы можете сначала отфильтровать нужные значения:
df.where(~F.col("code").isin(["code2"])).groupby("id").agg(F.collect_set("code")).show()
#+---+-----------------+
#| id|collect_set(code)|
#+---+-----------------+
#| a| [code1]|
#+---+-----------------+