У меня есть простое приложение структурированной потоковой передачи PySpark, которое преобразует входящие сообщения в график (используя GraphFrames). Упрощенный пример кода приведен ниже.
Код будет работать в течение ~ 50 пакетов до сбоя с ошибкой "G C Overhead ..." или "Размер кучи ...". К этому моменту у него будет график не более 100 вершин и 300 ребер.
В журнале logs + traceback показан вызов labelPropogation()
как запрос страницы памяти, которая затем вызывает ошибку OOM. Если я переключу функцию для pagerank()
или пользовательского алгоритма Прегеля (используя graphframes.lib.pregel
), cra sh все равно произойдет, все еще после ~ 50 пакетов. При использовании алгоритма Прегела трассировка покажет виновную функцию агрегирования в aggMsgs()
.
Это похоже на утечку памяти, но я не могу отследить кого-либо, кто испытал именно это. Так может проблема в моих настройках? (Обратите внимание, что в приведенном ниже коде я работаю в локальном режиме, но проблема также существует при отправке в кластер, поэтому я оставил в коде настройки, относящиеся к работе в кластере).
Интересно, использование GraphX под капотом может быть частью проблемы?
Любые идеи очень с благодарностью получены!
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import udf, window,count, col, sum
from pyspark.sql.types import *
import jsonpickle
LABEL_PROP_MAX_ITER = 5
def main():
spark = (
SparkSession
.builder
.appName('GraphFrames_Test')
.master("local[2]")
.config('spark.jars.packages', 'org.apache.spark:spark-sql-kafka-0-10_2.11:2.4.3,graphframes:graphframes:0.7.0-spark2.4-s_2.11')
.config('spark.sql.shuffle.partitions', 1)
.config('spark.python.worker.memory', '2G')
.config('spark.executor.memory', '2G')
.config('spark.driver.memory', '2G')
.config('spark.cleaner.ttl', '10s')
.config('spark.cleaner.periodicGC.interval', '5min')
.config('spark.cleaner.referenceTracking.blocking.shuffle', 'true')
.config('spark.cleaner.referenceTracking.cleanCheckpoints', 'true')
.config('spark.driver.maxResultSize', '500m')
.config('spark.graphx.pregel.checkpointInterval', 2)
.config('spark.executor.cores', '2G')
.config('spark.dynamicAllocation.enabled', 'true')
.config('spark.shuffle.service.enabled', 'true')
.config('spark.dynamicAllocation.maxExecutors', 2)
.getOrCreate()
)
# These imports can only happen once the graphframes JAR has been registered
from graphframes import GraphFrame
from graphframes.lib import Pregel
# Checkpointing
spark.sparkContext.setCheckpointDir('/tmp')
# Initialise logger
log4j = spark.sparkContext._jvm.org.apache.log4j
log4j.LogManager.getRootLogger().setLevel(log4j.Level.WARN)
# Define the stream to process
stream = (
spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", 'localhost:9092')
.option("subscribe", 'raw')
.option("startingOffsets", "latest")
.load()
)
# Schema for the elements we will use
schema = StructType([
StructField("src", LongType(), True),
StructField("src", TimestampType(), True),
StructField("dst", LongType(), False)
])
def parser(serialised_packet):
packet = jsonpickle.decode(serialised_packet)
src = packet.src
dst = packet.dst
created_at = packet.created_at
return [src, dst, created_at]
# Register UDF
parser_udf = udf(lambda value: parser(value), schema)
msgs = (
stream
.select(parser_udf('value').alias('msg'))
.select(
col('msg.src').alias('src'),
col('msg.dst').alias('dst'),
col('msg.created_at').alias('created_at')
)
.withWatermark('created_at', '1 day')
)
# Weighted edges
edges_df = (
msgs
.groupBy(window('created_at', '7 days', '1 day'), 'src', 'dst')
.count()
.withColumnRenamed('count', 'weight')
)
def label_clusters(edges_df: DataFrame, batch_id):
if not edges_df.rdd.isEmpty():
# Generate a DF of vertices
vertices_df = (
edges_df.select(edges_df.src.alias('id'))
.union(
edges_df.select(edges_df.dst.alias('id'))
)
)
g = GraphFrame(vertices_df, edges_df)
clusters_df = g.labelPropagation(LABEL_PROP_MAX_ITER)
labels = (
clusters_df
.groupBy('label')
.count()
.toPandas()
.label
.values
)
print(f"Batch: [{batch_id}]")
print(f"Vertex Count: [{g.vertices.count()}]")
print(f"Edge Count: [{g.edges.count()}]")
print(f"Num. Labels: [{len(labels)}]")
print("********************************")
g.unpersist(blocking=True)
(
edges_df
.writeStream
.outputMode("complete")
.foreachBatch(label_clusters)
.start()
)
if __name__ == '__main__':
main()