Писпарк НЛП - CountVectorizer Max DF или TF. Как отфильтровать общие вхождения из набора данных - PullRequest
0 голосов
/ 03 июля 2018

Я использую CountVectorizer, чтобы подготовить набор данных для ML. Я хочу отфильтровать редкие слова и для этого использую параметр CountVectorizer, minDF или minTF. Я также хотел бы удалить элементы, которые «часто» появляются в моем наборе данных. Я не вижу параметр maxTF или maxDF, который я могу установить. Есть ли хороший способ сделать это?

df = spark.createDataFrame(
[(0, ["a", "b", "c","b"]), (1, ["a", "b", "b", "c", "a"])],
["label", "raw"])

Так что в этом случае, если бы я хотел удалить параметры, которые появлялись «4» раза или 40% времени, и те, которые появлялись 2 раза или меньше. Это уберет «b» и «c».

В настоящее время я запускаю CountVectorizer(minDf=3......) для нижней границы запроса. Как я могу отфильтровать элементы, которые появляются чаще, чем я хочу моделировать.

1 Ответ

0 голосов
/ 05 июля 2018

Полагаю, вы запрашиваете параметр CountVectorizer, но похоже, что пока нет параметров для этого. Это не простой или практичный способ сделать это простым, но это работает. Я надеюсь, что это поможет вам:

from pyspark.sql.types import *
from pyspark.sql import functions as F

df = spark.createDataFrame(
[(0, ["a", "b", "c","b"]), (1, ["a", "b", "b", "c", "a"])],
["label", "raw"])

counts_df = df \
    .select(F.explode('raw').alias('testCol')) \
    .groupby('testCol') \
    .agg(F.count('testCol').alias('count')).persist() # this will be used multiple times

total = counts_df \
    .agg(F.sum('count').alias('total')) \
    .rdd.take(1)[0]['total']
min_times = 3
max_times = total * 0.4
filtered_elements = counts_df \
    .filter((min_times>F.col('count')) | (F.col('count')>max_times)) \
    .select('testCol') \
    .rdd.map(lambda row: row['testCol']) \
    .collect()

def removeElements(arr):
    return list(set(arr) - set(filtered_elements))

remove_udf = F.udf(removeElements, ArrayType(StringType()))
filtered_df = df \
    .withColumn('raw', remove_udf('raw'))

Результаты:

filtered_df.show()
+-----+---+
|label|raw|
+-----+---+
|    0|[a]|
|    1|[a]|
+-----+---+
...