Использование Graphframes + Structured Streaming приводит к OOM даже при очень низких объемах данных - PullRequest
0 голосов
/ 09 января 2020

У меня есть простое приложение структурированной потоковой передачи PySpark, которое преобразует входящие сообщения в график (используя GraphFrames). Упрощенный пример кода приведен ниже.

Код будет работать в течение ~ 50 пакетов до сбоя с ошибкой "G C Overhead ..." или "Размер кучи ...". К этому моменту у него будет график не более 100 вершин и 300 ребер.

В журнале logs + traceback показан вызов labelPropogation() как запрос страницы памяти, которая затем вызывает ошибку OOM. Если я переключу функцию для pagerank() или пользовательского алгоритма Прегеля (используя graphframes.lib.pregel), cra sh все равно произойдет, все еще после ~ 50 пакетов. При использовании алгоритма Прегела трассировка покажет виновную функцию агрегирования в aggMsgs().

Это похоже на утечку памяти, но я не могу отследить кого-либо, кто испытал именно это. Так может проблема в моих настройках? (Обратите внимание, что в приведенном ниже коде я работаю в локальном режиме, но проблема также существует при отправке в кластер, поэтому я оставил в коде настройки, относящиеся к работе в кластере).

Интересно, использование GraphX ​​под капотом может быть частью проблемы?

Любые идеи очень с благодарностью получены!

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import udf, window,count, col, sum
from pyspark.sql.types import *
import jsonpickle

LABEL_PROP_MAX_ITER = 5

def main():

    spark = (
        SparkSession
        .builder
        .appName('GraphFrames_Test')
        .master("local[2]")
        .config('spark.jars.packages', 'org.apache.spark:spark-sql-kafka-0-10_2.11:2.4.3,graphframes:graphframes:0.7.0-spark2.4-s_2.11')
        .config('spark.sql.shuffle.partitions', 1)
        .config('spark.python.worker.memory', '2G')
        .config('spark.executor.memory', '2G')
        .config('spark.driver.memory', '2G')
        .config('spark.cleaner.ttl', '10s')
        .config('spark.cleaner.periodicGC.interval', '5min')
        .config('spark.cleaner.referenceTracking.blocking.shuffle', 'true')
        .config('spark.cleaner.referenceTracking.cleanCheckpoints', 'true')
        .config('spark.driver.maxResultSize', '500m')
        .config('spark.graphx.pregel.checkpointInterval', 2)
        .config('spark.executor.cores', '2G')
        .config('spark.dynamicAllocation.enabled', 'true')
        .config('spark.shuffle.service.enabled', 'true')
        .config('spark.dynamicAllocation.maxExecutors', 2)
        .getOrCreate()
    )

    # These imports can only happen once the graphframes JAR has been registered
    from graphframes import GraphFrame
    from graphframes.lib import Pregel

    # Checkpointing
    spark.sparkContext.setCheckpointDir('/tmp')

    # Initialise logger
    log4j = spark.sparkContext._jvm.org.apache.log4j
    log4j.LogManager.getRootLogger().setLevel(log4j.Level.WARN)

    # Define the stream to process
    stream = (
        spark
        .readStream
        .format("kafka")
        .option("kafka.bootstrap.servers", 'localhost:9092')
        .option("subscribe", 'raw')
        .option("startingOffsets", "latest")
        .load()
    )

    # Schema for the elements we will use
    schema = StructType([
        StructField("src", LongType(), True),
        StructField("src", TimestampType(), True),
        StructField("dst", LongType(), False)
    ])

    def parser(serialised_packet):
        packet = jsonpickle.decode(serialised_packet)
        src = packet.src
        dst = packet.dst
        created_at = packet.created_at
        return [src, dst, created_at]

    # Register UDF
    parser_udf = udf(lambda value: parser(value), schema)

    msgs = (
        stream
        .select(parser_udf('value').alias('msg'))
        .select(
            col('msg.src').alias('src'),
            col('msg.dst').alias('dst'),
            col('msg.created_at').alias('created_at')
        )
        .withWatermark('created_at', '1 day')
    )

    # Weighted edges
    edges_df = (
        msgs
        .groupBy(window('created_at', '7 days', '1 day'), 'src', 'dst')
        .count()
        .withColumnRenamed('count', 'weight')
    )

    def label_clusters(edges_df: DataFrame, batch_id):

        if not edges_df.rdd.isEmpty():

            # Generate a DF of vertices
            vertices_df = (
                edges_df.select(edges_df.src.alias('id'))
                .union(
                    edges_df.select(edges_df.dst.alias('id'))
                )
            )

            g = GraphFrame(vertices_df, edges_df)

        clusters_df = g.labelPropagation(LABEL_PROP_MAX_ITER)

        labels = (
            clusters_df
            .groupBy('label')
            .count()
            .toPandas()
            .label
            .values
        )

        print(f"Batch: [{batch_id}]")
        print(f"Vertex Count: [{g.vertices.count()}]")
        print(f"Edge Count: [{g.edges.count()}]")
        print(f"Num. Labels: [{len(labels)}]")
        print("********************************")

        g.unpersist(blocking=True)

    (
        edges_df
        .writeStream
        .outputMode("complete")
        .foreachBatch(label_clusters)
        .start()
    )


if __name__ == '__main__':
    main()

...