Фильтр Spark DataFrame не работает должным образом со случайным - PullRequest
0 голосов
/ 19 февраля 2019

Это мой DataFrame

df.groupBy($"label").count.show
+-----+---------+                                                               
|label|    count|
+-----+---------+
|  0.0|400000000|
|  1.0| 10000000|
+-----+---------+

Я пытаюсь выполнить выборку записей с меткой == 0.0 со следующим:

val r = scala.util.Random
val df2 = df.filter($"label" === 1.0 || r.nextDouble > 0.5) // keep 50% of 0.0

Мой вывод выглядит следующим образом:

df2.groupBy($"label").count.show
+-----+--------+                                                                
|label|   count|
+-----+--------+
|  1.0|10000000|
+-----+--------+

1 Ответ

0 голосов
/ 19 февраля 2019

r.nextDouble является константой в выражении, поэтому фактическая оценка сильно отличается от того, что вы имеете в виду.В зависимости от фактического значения выборки это либо

scala> r.setSeed(0)

scala> $"label" === 1.0 || r.nextDouble > 0.5
res0: org.apache.spark.sql.Column = ((label = 1.0) OR true)

, либо

scala> r.setSeed(4096)

scala> $"label" === 1.0 || r.nextDouble > 0.5
res3: org.apache.spark.sql.Column = ((label = 1.0) OR false)

, поэтому после упрощения это просто:

true

(ведение всех записей)или

label = 1.0 

(сохраняя только те, что вы наблюдали) соответственно.

Для генерации случайных чисел вы должны использовать соответствующую функцию SQL

scala> import org.apache.spark.sql.functions.rand
import org.apache.spark.sql.functions.rand

scala> $"label" === 1.0 || rand > 0.5
res1: org.apache.spark.sql.Column = ((label = 1.0) OR (rand(3801516599083917286) > 0.5))

, хотя Spark уже предоставляет стратифицированные инструменты отбора проб:

df.stat.sampleBy(
  "label",  // column
  Map(0.0 -> 0.5, 1.0 -> 1.0),  // fractions
  42 // seed 
)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...