Сравнение `float` с` np.nan` в Spark Dataframe - PullRequest
6 голосов
/ 18 марта 2019

Это ожидаемое поведение? Я думал поднять проблему со Spark, но это кажется такой базовой функциональностью, что трудно представить, что здесь есть ошибка. Чего мне не хватает?

Python

import numpy as np

>>> np.nan < 0.0
False

>>> np.nan > 0.0
False

PySpark

from pyspark.sql.functions import col

df = spark.createDataFrame([(np.nan, 0.0),(0.0, np.nan)])
df.show()
#+---+---+
#| _1| _2|
#+---+---+
#|NaN|0.0|
#|0.0|NaN|
#+---+---+

df.printSchema()
#root
# |-- _1: double (nullable = true)
# |-- _2: double (nullable = true)

df.select(col("_1")> col("_2")).show()
#+---------+
#|(_1 > _2)|
#+---------+
#|     true|
#|    false|
#+---------+

1 Ответ

7 голосов
/ 18 марта 2019

Это ожидаемое и задокументированное поведение.Процитируем Семантика NaN официальный раздел Руководство по Spark SQL (выделено мной):

Специальная обработка для не-числа (NaN)при работе с типами float или double, которые не совсем соответствуют стандартной семантике с плавающей точкой.В частности:

  • NaN = NaN возвращает true.
  • В агрегатах все значения NaN группируются вместе.
  • NaN рассматривается как обычное значение в ключах соединения.
  • Значения NaN идут последними в возрастающем порядке, больше, чем любое другое числовое значение .

Ad Как вы видите, порядок упорядочения - не единственная разница, по сравнению с Python NaN.В частности, Spark считает NaN равным:

spark.sql("""
    WITH table AS (SELECT CAST('NaN' AS float) AS x, cast('NaN' AS float) AS y) 
    SELECT x = y, x != y FROM table
""").show()
+-------+-------------+
|(x = y)|(NOT (x = y))|
+-------+-------------+
|   true|        false|
+-------+-------------+

, тогда как обычный Python

float("NaN") == float("NaN"), float("NaN") != float("NaN")
(False, True)

и NumPy

np.nan == np.nan, np.nan != np.nan
(False, True)

don 't.

Вы можете проверить eqNullSafe docstring для дополнительных примеров.

Таким образом, чтобы получить желаемый результат, вам придется явно проверить NaN's

from pyspark.sql.functions import col, isnan, when

when(isnan("_1") | isnan("_2"), False).otherwise(col("_1") > col("_2"))
...