Как рассчитать совокупную сумму под лимит с помощью Spark? - PullRequest
3 голосов
/ 05 марта 2020

После нескольких попыток и некоторых исследований я застрял при попытке решить следующую проблему с Spark.

У меня есть Dataframe элементов с приоритетом и количеством.

+------+-------+--------+---+
|family|element|priority|qty|
+------+-------+--------+---+
|    f1| elmt 1|       1| 20|
|    f1| elmt 2|       2| 40|
|    f1| elmt 3|       3| 10|
|    f1| elmt 4|       4| 50|
|    f1| elmt 5|       5| 40|
|    f1| elmt 6|       6| 10|
|    f1| elmt 7|       7| 20|
|    f1| elmt 8|       8| 10|
+------+-------+--------+---+

У меня есть фиксированный лимит количества:

+------+--------+
|family|limitQty|
+------+--------+
|    f1|     100|
+------+--------+

Я хочу отметить как "ok" элементы, совокупная сумма которых находится ниже лимита. Вот ожидаемый результат:

+------+-------+--------+---+---+
|family|element|priority|qty| ok|
+------+-------+--------+---+---+
|    f1| elmt 1|       1| 20|  1| -> 20 < 100   => ok
|    f1| elmt 2|       2| 40|  1| -> 20 + 40 < 100  => ok
|    f1| elmt 3|       3| 10|  1| -> 20 + 40 + 10 < 100   => ok
|    f1| elmt 4|       4| 50|  0| -> 20 + 40 + 10 + 50 > 100   => ko 
|    f1| elmt 5|       5| 40|  0| -> 20 + 40 + 10 + 40 > 100   => ko  
|    f1| elmt 6|       6| 10|  1| -> 20 + 40 + 10 + 10 < 100   => ok
|    f1| elmt 7|       7| 20|  1| -> 20 + 40 + 10 + 10 + 20 < 100   => ok
|    f1| elmt 8|       8| 10|  0| -> 20 + 40 + 10 + 10 + 20 + 10 > 100   => ko
+------+-------+--------+---+---+  

Я пытаюсь решить, если с накопленной суммой:

    initDF
      .join(limitQtyDF, Seq("family"), "left_outer")
      .withColumn("cumulSum", sum($"qty").over(Window.partitionBy("family").orderBy("priority")))
      .withColumn("ok", when($"cumulSum" <= $"limitQty", 1).otherwise(0))
      .drop("cumulSum", "limitQty")

Но этого недостаточно, потому что элементы после элемента, который находится до предела не принимать во внимание. Я не могу найти способ решить это с помощью Spark. У вас есть идея?

Вот соответствующий Scala код:

    val sparkSession = SparkSession.builder()
      .master("local[*]")
      .getOrCreate()

    import sparkSession.implicits._

    val initDF = Seq(
      ("f1", "elmt 1", 1, 20),
      ("f1", "elmt 2", 2, 40),
      ("f1", "elmt 3", 3, 10),
      ("f1", "elmt 4", 4, 50),
      ("f1", "elmt 5", 5, 40),
      ("f1", "elmt 6", 6, 10),
      ("f1", "elmt 7", 7, 20),
      ("f1", "elmt 8", 8, 10)
    ).toDF("family", "element", "priority", "qty")

    val limitQtyDF = Seq(("f1", 100)).toDF("family", "limitQty")

    val expectedDF = Seq(
      ("f1", "elmt 1", 1, 20, 1),
      ("f1", "elmt 2", 2, 40, 1),
      ("f1", "elmt 3", 3, 10, 1),
      ("f1", "elmt 4", 4, 50, 0),
      ("f1", "elmt 5", 5, 40, 0),
      ("f1", "elmt 6", 6, 10, 1),
      ("f1", "elmt 7", 7, 20, 1),
      ("f1", "elmt 8", 8, 10, 0)
    ).toDF("family", "element", "priority", "qty", "ok").show()

Спасибо за помощь!

Ответы [ 3 ]

1 голос
/ 05 марта 2020

Решение показано ниже:

scala> initDF.show
+------+-------+--------+---+
|family|element|priority|qty|
+------+-------+--------+---+
|    f1| elmt 1|       1| 20|
|    f1| elmt 2|       2| 40|
|    f1| elmt 3|       3| 10|
|    f1| elmt 4|       4| 50|
|    f1| elmt 5|       5| 40|
|    f1| elmt 6|       6| 10|
|    f1| elmt 7|       7| 20|
|    f1| elmt 8|       8| 10|
+------+-------+--------+---+

scala> val df1 = initDF.groupBy("family").agg(collect_list("qty").as("comb_qty"), collect_list("priority").as("comb_prior"), collect_list("element").as("comb_elem"))
df1: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 2 more fields]

scala> df1.show
+------+--------------------+--------------------+--------------------+
|family|            comb_qty|          comb_prior|           comb_elem|
+------+--------------------+--------------------+--------------------+
|    f1|[20, 40, 10, 50, ...|[1, 2, 3, 4, 5, 6...|[elmt 1, elmt 2, ...|
+------+--------------------+--------------------+--------------------+


scala> val df2 = df1.join(limitQtyDF, df1("family") === limitQtyDF("family")).drop(limitQtyDF("family"))
df2: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 3 more fields]

scala> df2.show
+------+--------------------+--------------------+--------------------+--------+
|family|            comb_qty|          comb_prior|           comb_elem|limitQty|
+------+--------------------+--------------------+--------------------+--------+
|    f1|[20, 40, 10, 50, ...|[1, 2, 3, 4, 5, 6...|[elmt 1, elmt 2, ...|     100|
+------+--------------------+--------------------+--------------------+--------+


scala> def validCheck = (qty: Seq[Int], limit: Int) => {
     | var sum = 0
     | qty.map(elem => {
     | if (elem + sum <= limit) {
     | sum = sum + elem
     | 1}else{
     | 0
     | }})}
validCheck: (scala.collection.mutable.Seq[Int], Int) => scala.collection.mutable.Seq[Int]

scala> val newUdf = udf(validCheck)
newUdf: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function2>,ArrayType(IntegerType,false),Some(List(ArrayType(IntegerType,false), IntegerType)))

val df3 = df2.withColumn("valid", newUdf(col("comb_qty"),col("limitQty"))).drop("limitQty")
df3: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 3 more fields]

scala> df3.show
+------+--------------------+--------------------+--------------------+--------------------+
|family|            comb_qty|          comb_prior|           comb_elem|               valid|
+------+--------------------+--------------------+--------------------+--------------------+
|    f1|[20, 40, 10, 50, ...|[1, 2, 3, 4, 5, 6...|[elmt 1, elmt 2, ...|[1, 1, 1, 0, 0, 1...|
+------+--------------------+--------------------+--------------------+--------------------+

scala> val myUdf = udf((qty: Seq[Int], prior: Seq[Int], elem: Seq[String], valid: Seq[Int]) => {
     | elem zip prior zip qty zip valid map{
     | case (((a,b),c),d) => (a,b,c,d)}
     | }
     | )

scala> val df4 = df3.withColumn("combined", myUdf(col("comb_qty"),col("comb_prior"),col("comb_elem"),col("valid")))
df4: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 4 more fields]



scala> val df5 = df4.drop("comb_qty","comb_prior","comb_elem","valid")
df5: org.apache.spark.sql.DataFrame = [family: string, combined: array<struct<_1:string,_2:int,_3:int,_4:int>>]

scala> df5.show(false)
+------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+
|family|combined                                                                                                                                                        |
+------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+
|f1    |[[elmt 1, 1, 20, 1], [elmt 2, 2, 40, 1], [elmt 3, 3, 10, 1], [elmt 4, 4, 50, 0], [elmt 5, 5, 40, 0], [elmt 6, 6, 10, 1], [elmt 7, 7, 20, 1], [elmt 8, 8, 10, 0]]|
+------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+

scala> val df6 = df5.withColumn("combined",explode(col("combined")))
df6: org.apache.spark.sql.DataFrame = [family: string, combined: struct<_1: string, _2: int ... 2 more fields>]

scala> df6.show
+------+------------------+
|family|          combined|
+------+------------------+
|    f1|[elmt 1, 1, 20, 1]|
|    f1|[elmt 2, 2, 40, 1]|
|    f1|[elmt 3, 3, 10, 1]|
|    f1|[elmt 4, 4, 50, 0]|
|    f1|[elmt 5, 5, 40, 0]|
|    f1|[elmt 6, 6, 10, 1]|
|    f1|[elmt 7, 7, 20, 1]|
|    f1|[elmt 8, 8, 10, 0]|
+------+------------------+

scala> val df7 = df6.select("family", "combined._1", "combined._2", "combined._3", "combined._4").withColumnRenamed("_1","element").withColumnRenamed("_2","priority").withColumnRenamed("_3", "qty").withColumnRenamed("_4","ok")
df7: org.apache.spark.sql.DataFrame = [family: string, element: string ... 3 more fields]

scala> df7.show
+------+-------+--------+---+---+
|family|element|priority|qty| ok|
+------+-------+--------+---+---+
|    f1| elmt 1|       1| 20|  1|
|    f1| elmt 2|       2| 40|  1|
|    f1| elmt 3|       3| 10|  1|
|    f1| elmt 4|       4| 50|  0|
|    f1| elmt 5|       5| 40|  0|
|    f1| elmt 6|       6| 10|  1|
|    f1| elmt 7|       7| 20|  1|
|    f1| elmt 8|       8| 10|  0|
+------+-------+--------+---+---+

Дайте мне знать, если это поможет !!

0 голосов
/ 05 марта 2020

Я новичок в Spark, поэтому это решение может быть неоптимальным. Я предполагаю, что значение 100 является входом для программы здесь. В этом случае:

case class Frame(family:String, element : String, priority : Int, qty :Int)

import scala.collection.JavaConverters._
val ans = df.as[Frame].toLocalIterator
  .asScala
  .foldLeft((Seq.empty[Int],0))((acc,a) => 
    if(acc._2 + a.qty <= 100) (acc._1 :+ a.priority, acc._2 + a.qty) else acc)._1

df.withColumn("OK" , when($"priority".isin(ans :_*), 1).otherwise(0)).show

приводит к:

+------+-------+--------+---+--------+
|family|element|priority|qty|OK      |
+------+-------+--------+---+--------+
|    f1| elmt 1|       1| 20|       1|
|    f1| elmt 2|       2| 40|       1|
|    f1| elmt 3|       3| 10|       1|
|    f1| elmt 4|       4| 50|       0|
|    f1| elmt 5|       5| 40|       0|
|    f1| elmt 6|       6| 10|       1|
|    f1| elmt 7|       7| 20|       1|
|    f1| elmt 8|       8| 10|       0|
+------+-------+--------+---+--------+

Идея состоит в том, чтобы просто получить итератор Scala и извлечь из него участвующие значения priority, а затем использовать те значения для фильтрации участвующих строк. Учитывая, что это решение собирает все данные в памяти на одной машине, оно может столкнуться с проблемами памяти, если размер кадра данных слишком велик для размещения в памяти.

0 голосов
/ 05 марта 2020

Другим способом сделать это будет подход, основанный на RDD, путем итерации строка за строкой.

var bufferRow: collection.mutable.Buffer[Row] = collection.mutable.Buffer.empty[Row]
var tempSum: Double = 0
val iterator = df.collect.iterator
while(iterator.hasNext){
  val record = iterator.next()
  val y = record.getAs[Integer]("qty")
  tempSum = tempSum + y
  print(record)
  if (tempSum <= 100.0 ) {
    bufferRow = bufferRow ++ Seq(transformRow(record,1))
  }
  else{
    bufferRow = bufferRow ++ Seq(transformRow(record,0))
    tempSum = tempSum - y
  }
}

Определение transformRow функции, которая используется для добавления столбца в строку.

def transformRow(row: Row,flag : Int): Row =  Row.fromSeq(row.toSeq ++ Array[Integer](flag))

Следующим шагом будет добавление в схему дополнительного столбца.

val newSchema = StructType(df.schema.fields ++ Array(StructField("C_Sum", IntegerType, false))

С созданием нового кадра данных.

val outputdf = spark.createDataFrame(spark.sparkContext.parallelize(bufferRow.toSeq),newSchema)

Выходной кадр данных:

+------+-------+--------+---+-----+
|family|element|priority|qty|C_Sum|
+------+-------+--------+---+-----+
|    f1|  elmt1|       1| 20|    1|
|    f1|  elmt2|       2| 40|    1|
|    f1|  elmt3|       3| 10|    1|
|    f1|  elmt4|       4| 50|    0|
|    f1|  elmt5|       5| 40|    0|
|    f1|  elmt6|       6| 10|    1|
|    f1|  elmt7|       7| 20|    1|
|    f1|  elmt8|       8| 10|    0|
+------+-------+--------+---+-----+
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...