Как получить индекс самого высокого значения в списке на строку в Spark DataFrame? [PySpark] - PullRequest
0 голосов
/ 28 января 2020

Я выполнил моделирование LDA topi c и сохранил его в lda_model.

После преобразования исходного набора входных данных я извлекаю DataFrame. Одним из столбцов является topicDistribution, где вероятность того, что эта строка принадлежит каждой топи c из модели LDA. Поэтому я хочу получить индекс максимального значения в списке на строку.

df -- | 'list_of_words' | 'index ' | 'topicDistribution' | 
       ['product','...']     0       [0.08,0.2,0.4,0.0001]
          .....             ...         ........

Я хочу преобразовать df так, чтобы был добавлен дополнительный столбец, представляющий собой argmax списка topicDistribution для каждой строки.

df_transformed --  | 'list_of_words' | 'index' | 'topicDistribution' | 'topicID' |
                    ['product','...']     0     [0.08,0.2,0.4,0.0001]      2
                       ......            ....         .....              ....

Как бы я это сделал?

1 Ответ

0 голосов
/ 28 января 2020

Вы можете создать пользовательскую функцию, чтобы получить максимальный индекс

from pyspark.sql import functions as f
from pyspark.sql.types import IntegerType

max_index = f.udf(lambda x: x.index(max(x)), IntegerType())
df = df.withColumn("topicID", max_index("topicDistribution"))

Пример

>>> from pyspark.sql import functions as f
>>> from pyspark.sql.types import IntegerType 
>>> df = spark.createDataFrame([{"topicDistribution": [0.2, 0.3, 0.5]}])
>>> df.show()
+-----------------+
|topicDistribution|
+-----------------+
|  [0.2, 0.3, 0.5]|
+-----------------+

>>> max_index = f.udf(lambda x: x.index(max(x)), IntegerType())
>>> df.withColumn("topicID", max_index("topicDistribution")).show()
+-----------------+-------+
|topicDistribution|topicID|
+-----------------+-------+
|  [0.2, 0.3, 0.5]|      2|
+-----------------+-------+

Редактировать:

Поскольку вы упомянули, что списки в topicDistribution являются numpy массивами, вы можете обновить max_index udf следующим образом:

max_index = f.udf(lambda x: x.tolist().index(max(x)), IntegerType())
...