Можно ли каким-либо образом настроить приведенный ниже код Pyspark MLib, который рассчитывает квантили для повышения производительности? - PullRequest
0 голосов
/ 24 апреля 2019

Я пытаюсь найти квантили для каждого столбца в таблице для различных фирм, использующих спарк 1.6

У меня около 5000 записей в firm_list и 300 записей в attr_lst.Количество записей в таблице составляет около 200000.

Я использую 10 исполнителей, каждый из которых имеет 16 ГБ памяти.

В настоящее время это занимает около 1 секунды для каждого вычисления квантиля и 2 минуты для всего преобразования.с такими темпами он будет работать около 10000 минут для 5000 фирм.

Пожалуйста, дайте мне знать, как я могу оптимизировать производительность.

    from __future__ import print_function
    from src.utils import sql_service, apa_constant as 
    constant,file_io_service
    from pyspark.sql.functions import monotonicallyIncreasingId
    from pyspark.sql.functions import lit,col,broadcast
    from pyspark.ml.feature import Bucketizer
    from pyspark.ml import Pipeline
    from pyspark.sql.types import StructType
    from pyspark.sql.types import StructField
    from pyspark.sql.types import StringType,DoubleType
    from pyspark.sql.types import *
    from pyspark.ml.feature import *
    from concurrent.futures import *
    from functools import reduce
    from pyspark.sql import DataFrame
    import pyspark
    import numpy as np

    def generate_quantile_reports(spark_context, hive_context, 
          log,attribute_type, **kwargs):

        sql = """describe 
        {}.apa_all_attrs_consortium""".format(kwargs['sem_db'])
        op = hive_context.sql(sql)
        res = op.withColumn("ordinal_position", 
                             monotonicallyIncreasingId())
       res.registerTempTable('attribs')

       attr_lst = hive_context.sql(
                            """select col_name from attribs where 
                              ordinal_position > 24 AND col_name not like 
                              '%vehicle%'
                            AND col_name not like '%cluster_num%'
                            AND col_name not like '%value_seg%' order by 
                            ordinal_position""").collect()

        sql = """select distinct firm_id, firm_name
              from {}.apa_all_attrs_consortium where ud_rep = 1
              and lower(channel) not in ('ria', 'clearing')
              order by firm_id limit 5
              """.format(kwargs['sem_db'])
        dat = hive_context.sql(sql)
        firm_list = dat.collect()
        sql = """select entity_id, cast(firm_id as double), %s from 
                 %s.apa_all_attrs_consortium where ud_rep = 1
                 and lower(channel) not in ('ria', 'clearing') cluster by 
                 entity_id""" % (
                  ", ".join("cast(" + str(attr.col_name) + " as double)" for 
                  attr in attr_lst), kwargs['sem_db'])

        df = hive_context.sql(sql)
        qrtl_list = []
        df.cache()
        df.count()
        counter = 0
       for (fm,fnm) in firm_list:
          df2 = df[df.firm_id == fm]
          df2 = df2.replace(0, np.nan)
          df_main = df2.drop('firm_id')
          counter += 1
          colNames = []
          quartileList = []
          bucketizerList = []
          for var in attr_lst:
              colNames.append(var.col_name)
          jdf = df2._jdf
          bindt = spark_context
                  ._jvm.com.dstsystems.apa.util
                  .DFQuantileFunction.approxQuantile 
                 (jdf,colNames,[0.0,0.25,0.5,0.75,1.0],0.0)
          for i in range(len(bindt)):            
            quartile = sorted(list(set(list(bindt[i]))))
            quartile = [-float("inf")] + quartile
            quartile.insert(len(quartile),float("inf"))
            quartile.insert(len(quartile),float("NaN"))
            df_main = df_main.filter(df_main[colNames[i]].isNotNull())
            bucketizerList                         
            .append(Bucketizer().setInputCol(colNames[i])       
            .setOutputCol("{}_quantile".format(colNames[i]))
            .setSplits(quartile))
          path = " {}/tmpPqtDir/apa_team_broker_quartiles"
                  .format(kwargs['semzone_path'])
          qrtl_list
          .append(Pipeline(stages=bucketizerList)
          .fit(df_main).transform(df_main))
        finalDF = reduce(DataFrame.unionAll, qrtl_list)
        finalDF.repartition(200)
        .write.mode("overwrite").option("header","true").parquet(path)
        df.unpersist()
...