Суммирование массива в столбце с использованием агрегатной функции - PullRequest
0 голосов
/ 28 октября 2019

Я пытаюсь суммировать поле, содержащее массив,

a = sc.parallelize([("a", [1,1,1]),
                    ("a", [2,2])])

a = a.toDF(["g", "arr_val"])

a.registerTempTable('a')

sql = """
select 
aggregate(arr_val, 0, (acc, x) -> acc + x) as sum
from a
"""

spark.sql(sql).show()

Но я сталкиваюсь со следующей ошибкой:

An error occurred while calling o24.sql.
: org.apache.spark.sql.AnalysisException: cannot resolve 'aggregate(a.`arr_val`, 0, lambdafunction((CAST(namedlambdavariable() AS BIGINT) + namedlambdavariable()), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))' due to data type mismatch: argument 3 requires int type, however, 'lambdafunction((CAST(namedlambdavariable() AS BIGINT) + namedlambdavariable()), namedlambdavariable(), namedlambdavariable())' is of bigint type.; line 3 pos 0;

Как мне заставить это работать

1 Ответ

0 голосов
/ 28 октября 2019

Вам необходимо привести значения в аккумуляторе, например, к числу с плавающей запятой:

a = sc.parallelize([("a", [1,1,1]),
                    ("a", [2,2])])
a = a.toDF(["g", "arr_val"])

a.registerTempTable('a')

sql = """
select 
aggregate(arr_val, cast(0 as float), (acc, x) -> acc + cast(x as float)) as sum
from a
"""

spark.sql(sql).show()
...