Вызов cv.fit()
возвращает CountVectorizerModel
, который (AFAIK) хранит словарь, но не сохраняет счет.Словарь является свойством модели (он должен знать, какие слова считать), но счетчик является свойством DataFrame (а не модели).Вы можете применить функцию преобразования подобранной модели, чтобы получить значения для любого DataFrame.
При этом есть два способа получить желаемый результат.
1.Использование существующей модели векторизатора счета
Вы можете использовать pyspark.sql.functions.explode()
и pyspark.sql.functions.collect_list()
, чтобы собрать весь корпус в один ряд.В иллюстративных целях давайте рассмотрим новый DataFrame df2
, который содержит несколько слов, невидимых встроенной CountVectorizer
:
import pyspark.sql.functions as f
df2 = sqlCtx.createDataFrame([(0, ["a", "b", "c", "x", "y"]),
(1, ["a", "b", "b", "c", "a"])],
["label", "raw"])
combined_df = (
df2.select(f.explode('raw').alias('col'))
.select(f.collect_list('col').alias('raw'))
)
combined_df.show(truncate=False)
#+------------------------------+
#|raw |
#+------------------------------+
#|[a, b, c, x, y, a, b, b, c, a]|
#+------------------------------+
Затем используйте подобранную модель, чтобы преобразовать ее в число и собрать результаты:
counts = model.transform(combined_df).select('vectors').collect()
print(counts)
#[Row(vectors=SparseVector(3, {0: 3.0, 1: 3.0, 2: 2.0}))]
Далее zip
подсчитывает и словарь вместе и использует конструктор dict
, чтобы получить желаемый результат:
print(dict(zip(model.vocabulary, counts[0]['vectors'].values)))
#{u'a': 3.0, u'b': 3.0, u'c': 2.0}
Как вы правильно указали в комментариях, это будет учитывать только те слова, которые являются частью словаря CountVectorizerModel
.Любые другие слова будут игнорироваться.Следовательно, мы не видим записей для "x"
или "y"
.
2.Используйте агрегатные функции DataFrame
. Или вы можете пропустить CountVectorizer
и получить вывод, используя groupBy()
.Это более общее решение в том смысле, что оно даст подсчет всех слов в кадре данных, а не только слов в словаре:
counts = df2.select(f.explode('raw').alias('col')).groupBy('col').count().collect()
print(counts)
#[Row(col=u'x', count=1), Row(col=u'y', count=1), Row(col=u'c', count=2),
# Row(col=u'b', count=3), Row(col=u'a', count=3)]
Теперь просто используйте dict
понимание:
print({row['col']: row['count'] for row in counts})
#{u'a': 3, u'b': 3, u'c': 2, u'x': 1, u'y': 1}
Здесь у нас также есть счет для "x"
и "y"
.