Spark Scala - результат выравнивания скользящей средней с функцией pandas - PullRequest
0 голосов
/ 08 февраля 2020

Я пытаюсь преобразовать функцию панды скользящего среднего в искру scala. Тем не менее, похоже, что оба дают разные результаты.

Pandas код:

dummy = {"value": [10, 20, 30,40, 50, 60, 70, 80, 90, 100],
         "name": ["aa" for i in range(0,10)]}
df= pd.DataFrame(dummy, columns=['name', 'value'])
pprint(df)
pprint(df.groupby('name').rolling(2).mean().shift(1))

Искровой код

val df =List(("ABC", 10),
              ("ABC", 20),
              ("ABC", 30),
              ("ABC", 40),
              ("ABC", 50),
              ("ABC", 60),
              ("ABC", 70),
              ("ABC", 80),
              ("ABC", 90),
              ("ABC", 100)
            ).
          toDF("name", "value")

val window = Window.partitionBy($"name").orderBy($"value").rowsBetween(-2,1)
val df2 = df.withColumn("rolling_average", avg($"value") over(window))
display(df2)

Pandas Выход :

        value
name         
aa   0    NaN
     1    NaN
     2   15.0
     3   25.0
     4   35.0
     5   45.0
     6   55.0
     7   65.0
     8   75.0
     9   85.0

Искровой выход

+----+-----+---------------+
|name|value|rolling_average|
+----+-----+---------------+
| ABC|   10|           15.0|
| ABC|   20|           20.0|
| ABC|   30|           25.0|
| ABC|   40|           35.0|
| ABC|   50|           45.0|
| ABC|   60|           55.0|
| ABC|   70|           65.0|
| ABC|   80|           75.0|
| ABC|   90|           85.0|
| ABC|  100|           90.0|
+----+-----+---------------+

Есть ли способ, с помощью которого можно включить функцию искрового окна для получения аналогичного выхода как Pandas Функция?

1 Ответ

0 голосов
/ 08 февраля 2020

Просто изменил ваши строки между начальной точкой с 1 на -1.

Получил почти идентичный результат, затем изменил 2-ую запись 10 на ноль, чтобы дать вам точно такой же результат.

%scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{avg,col,lit,when}
val df =List(("ABC", 10),
              ("ABC", 20),
              ("ABC", 30),
              ("ABC", 40),
              ("ABC", 50),
              ("ABC", 60),
              ("ABC", 70),
              ("ABC", 80),
              ("ABC", 90),
              ("ABC", 100)
            ).
          toDF("name", "value")

val window = Window.partitionBy($"name").orderBy($"value").rowsBetween(-2, -1)
val df2 = df.withColumn("rolling_average", avg($"value") over(window)).withColumn("rolling_average", when(col("rolling_average")===10, lit(null)).otherwise(col("rolling_average")))
df2.show()

+----+-----+---------------+
|name|value|rolling_average|
+----+-----+---------------+
| ABC|   10|           null|
| ABC|   20|           null|
| ABC|   30|           15.0|
| ABC|   40|           25.0|
| ABC|   50|           35.0|
| ABC|   60|           45.0|
| ABC|   70|           55.0|
| ABC|   80|           65.0|
| ABC|   90|           75.0|
| ABC|  100|           85.0|
+----+-----+---------------+
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...