Вы можете создать пользовательскую функцию, чтобы получить максимальный индекс
from pyspark.sql import functions as f
from pyspark.sql.types import IntegerType
max_index = f.udf(lambda x: x.index(max(x)), IntegerType())
df = df.withColumn("topicID", max_index("topicDistribution"))
Пример
>>> from pyspark.sql import functions as f
>>> from pyspark.sql.types import IntegerType
>>> df = spark.createDataFrame([{"topicDistribution": [0.2, 0.3, 0.5]}])
>>> df.show()
+-----------------+
|topicDistribution|
+-----------------+
| [0.2, 0.3, 0.5]|
+-----------------+
>>> max_index = f.udf(lambda x: x.index(max(x)), IntegerType())
>>> df.withColumn("topicID", max_index("topicDistribution")).show()
+-----------------+-------+
|topicDistribution|topicID|
+-----------------+-------+
| [0.2, 0.3, 0.5]| 2|
+-----------------+-------+
Редактировать:
Поскольку вы упомянули, что списки в topicDistribution
являются numpy массивами, вы можете обновить max_index
udf
следующим образом:
max_index = f.udf(lambda x: x.tolist().index(max(x)), IntegerType())