I am trying to compute quantiles for every column in a table, separately for each firm, using Spark 1.6.
I have about 5,000 records in firm_list and 300 records in attr_lst. The table itself has roughly 200,000 records.
I am running 10 executors with 16 GB of memory each.
At the moment each quantile computation takes about 1 second, and the full transformation of a single firm takes about 2 minutes; at that rate it would run for roughly 10,000 minutes for all 5,000 firms.
Please let me know how I can optimize the performance.
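# Spark 1.6 job: bucketize ~300 numeric attribute columns into quartiles,
# separately for each firm, from <sem_db>.apa_all_attrs_consortium.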
from __future__ import print_function
from src.utils import sql_service, apa_constant as constant, file_io_service
from pyspark.sql.functions import monotonicallyIncreasingId
from pyspark.sql.functions import lit, col, broadcast
from pyspark.ml.feature import Bucketizer
from pyspark.ml import Pipeline
from pyspark.sql.types import StructType, StructField, StringType, DoubleType
from pyspark.sql.types import *
from pyspark.ml.feature import *
from concurrent.futures import *
from functools import reduce
from pyspark.sql import DataFrame
import pyspark
import numpy as np
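# Builds quartile buckets for every attribute column, one firm at a time,
# and writes the combined result out as parquet.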
def generate_quantile_reports(spark_context, hive_context, log,
                              attribute_type, **kwargs):
    sql = """describe {}.apa_all_attrs_consortium""".format(kwargs['sem_db'])
    op = hive_context.sql(sql)
    res = op.withColumn("ordinal_position", monotonicallyIncreasingId())
    res.registerTempTable('attribs')
    attr_lst = hive_context.sql(
        """select col_name from attribs where
           ordinal_position > 24 AND col_name not like '%vehicle%'
           AND col_name not like '%cluster_num%'
           AND col_name not like '%value_seg%'
           order by ordinal_position""").collect()
sql = """select distinct firm_id, firm_name
from {}.apa_all_attrs_consortium where ud_rep = 1
and lower(channel) not in ('ria', 'clearing')
order by firm_id limit 5
""".format(kwargs['sem_db'])
dat = hive_context.sql(sql)
firm_list = dat.collect()
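    # firm_list drives the per-firm loop below; the full run has about 5,000
    # firms (the "limit 5" above presumably caps a test run).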
sql = """select entity_id, cast(firm_id as double), %s from
%s.apa_all_attrs_consortium where ud_rep = 1
and lower(channel) not in ('ria', 'clearing') cluster by
entity_id""" % (
", ".join("cast(" + str(attr.col_name) + " as double)" for
attr in attr_lst), kwargs['sem_db'])
df = hive_context.sql(sql)
qrtl_list = []
df.cache()
df.count()
counter = 0
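    # df is cached once and reused for every firm; everything below runs once
    # per firm, which is where the ~2 minutes per firm is spent.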
    for (fm, fnm) in firm_list:
        df2 = df[df.firm_id == fm]
        df2 = df2.replace(0, np.nan)
        df_main = df2.drop('firm_id')
        counter += 1
        colNames = []
        quartileList = []
        bucketizerList = []
        for var in attr_lst:
            colNames.append(var.col_name)
        jdf = df2._jdf
        bindt = spark_context._jvm.com.dstsystems.apa.util.DFQuantileFunction \
            .approxQuantile(jdf, colNames, [0.0, 0.25, 0.5, 0.75, 1.0], 0.0)
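        # bindt contains, for each column, the [min, q1, median, q3, max]
        # values returned by the JVM-side approxQuantile helper for this firm.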
        for i in range(len(bindt)):
            quartile = sorted(list(set(list(bindt[i]))))
            quartile = [-float("inf")] + quartile
            quartile.insert(len(quartile), float("inf"))
            quartile.insert(len(quartile), float("NaN"))
            df_main = df_main.filter(df_main[colNames[i]].isNotNull())
            bucketizerList.append(Bucketizer()
                                  .setInputCol(colNames[i])
                                  .setOutputCol("{}_quantile".format(colNames[i]))
                                  .setSplits(quartile))
        path = "{}/tmpPqtDir/apa_team_broker_quartiles".format(kwargs['semzone_path'])
        qrtl_list.append(Pipeline(stages=bucketizerList).fit(df_main).transform(df_main))
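    # One transformed DataFrame per firm accumulates in qrtl_list; they are
    # unioned and written to a single parquet directory below.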
    finalDF = reduce(DataFrame.unionAll, qrtl_list)
    finalDF.repartition(200) \
        .write.mode("overwrite").option("header", "true").parquet(path)
    df.unpersist()
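For reference, this is the kind of single-pass alternative I have been considering instead of looping over firms. It is only a sketch under a few assumptions: Hive's percentile_approx UDAF has to be usable from this HiveContext, it reuses attr_lst and kwargs from the function above, and it does not reproduce the zero-to-NaN replacement or the Bucketizer step, so the output (one array of quartiles per column per firm) differs from what the loop produces.

# Sketch: compute quartiles for every firm and every attribute column in one
# grouped aggregation instead of one approxQuantile call per firm.
# percentile_approx returns an array [q1, median, q3]; the min/max values
# (the 0.0 and 1.0 splits used above) would need separate min()/max() aggregates.
quartile_exprs = ", ".join(
    "percentile_approx(cast({0} as double), array(0.25, 0.5, 0.75)) as {0}_quartiles"
    .format(attr.col_name) for attr in attr_lst)
sql = """select firm_id, {exprs}
         from {db}.apa_all_attrs_consortium
         where ud_rep = 1
         and lower(channel) not in ('ria', 'clearing')
         group by firm_id""".format(exprs=quartile_exprs, db=kwargs['sem_db'])
all_firm_quartiles = hive_context.sql(sql)

Would something along those lines scale better, or is there a way to make the per-firm loop itself faster?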