Как подсчитать категориальные переменные по группам в PySpark? - PullRequest
0 голосов
/ 25 августа 2018

Я могу запустить следующий код и получить включенные выходные данные, но он не работает, если один и тот же AlertType появляется несколько раз для SessionID. Мне нужен способ получить значения, отличные от 1,0 в столбце OHE на выходе в этом случае. Ошибка была связана с итераторами.

Мне помогли этот вопрос и ответ: Как добавить редкие векторы после группировки с помощью Spark SQL?

columns=['SessionID','AlertType']
vals=[
    (1,0),
    (1,1),
    (1,2),
    (1,3),
    (1,4),
    (2,0),
    (2,1),
    (2,2),
    (2,3),
    (2,4),
]

df=spark.createDataFrame(vals,columns)
df.show()

+---------+---------+
|SessionID|AlertType|
+---------+---------+
|        1|        0|
|        1|        1|
|        1|        2|
|        1|        3|
|        1|        4|
|        2|        0|
|        2|        1|
|        2|        2|
|        2|        3|
|        2|        4|
+---------+---------+

from pyspark.sql.functions import collect_list,max,lit, udf
from pyspark.ml.linalg import Vectors,VectorUDT

def encode(arr,length):
    vec_args=length,[(x,1.0) for x in arr]
    return Vectors.sparse(*vec_args)
encode_udf=udf(encode,VectorUDT())

# do stringindexer stuff
from pyspark.ml.feature import StringIndexer
indexer=StringIndexer(inputCol='AlertType',outputCol='AlertTypeStrIndexed').fit(df)
df_strIndexed=indexer.transform(df)
df_strIndexed.show()

+---------+---------+-------------------+
|SessionID|AlertType|AlertTypeStrIndexed|
+---------+---------+-------------------+
|        1|        0|                2.0|
|        1|        1|                1.0|
|        1|        2|                3.0|
|        1|        3|                4.0|
|        1|        4|                0.0|
|        2|        0|                2.0|
|        2|        1|                1.0|
|        2|        2|                3.0|
|        2|        3|                4.0|
|        2|        4|                0.0|
+---------+---------+-------------------+

df_strIndexed.agg(max(df_strIndexed["AlertTypeStrIndexed"])).show()
feats = df_strIndexed.agg(max(df_strIndexed["AlertTypeStrIndexed"])).take(1)[0][0] + 1

df_OHE_grouped=df_strIndexed.groupBy("SessionID") \
               .agg(collect_list("AlertTypeStrIndexed")
               .alias("AlertArray")) \
               .select("SessionID", encode_udf("AlertArray", lit(feats)) \
                       .alias("OHE")).show(truncate=False)

+---------+-------------------------------------+
|SessionID|OHE                                  |
+---------+-------------------------------------+
|1        |(5,[0,1,2,3,4],[1.0,1.0,1.0,1.0,1.0])|
|2        |(5,[0,1,2,3,4],[1.0,1.0,1.0,1.0,1.0])|
...