Десятичная точность для класса случая Spark Dataset Encoder - PullRequest
0 голосов
/ 07 декабря 2018

У меня есть данные CSV:

"id","price"
"1","79.07"
"2","91.27"
"3","85.6"

Чтение с использованием SparkSession:

def readToDs(resource: String, schema: StructType): Dataset = {
    sparkSession.read
      .option("header", "true")
      .schema(schema)
      .csv(resource)
      .as[ItemPrice]
}

Класс дела:

case class ItemPrice(id: Long, price: BigDecimal)

Набор данных печати:

def main(args: Array[String]): Unit = {
    val prices: Dataset = 
        readToDs("src/main/resources/app/data.csv", Encoders.product[ItemPrice].schema);
    prices.show();
}

Вывод:

+----------+--------------------+
|        id|               price|
+----------+--------------------+
|         1|79.07000000000000...|
|         2|91.27000000000000...|
|         3|85.60000000000000...|
+----------+--------------------+

Требуемый вывод:

+----------+--------+
|        id|   price|
+----------+--------+
|         1|   79.07|
|         2|   91.27|
|         3|   85.6 |
+----------+--------+

Опция, которую я уже знаю:

Определение схемы вручную с помощью жестко закодированного порядка столбцов иТипы данных, такие как:

def defineSchema(): StructType =
    StructType(
      Seq(StructField("id", LongType, nullable = false)) :+
        StructField("price", DecimalType(3, 2), nullable = false)
    )

И использовать его как:

val prices: Dataset = readToDs("src/main/resources/app/data.csv", defineSchema);

Как я могу установить точность (3,2) без ручного определения всей структуры?

Ответы [ 3 ]

0 голосов
/ 07 декабря 2018

Предполагая, что вы получите свой CSV как

scala> val df = Seq(("1","79.07","89.04"),("2","91.27","1.02"),("3","85.6","10.01")).toDF("item","price1","price2")
df: org.apache.spark.sql.DataFrame = [item: string, price1: string ... 1 more field]

scala> df.printSchema
root
 |-- item: string (nullable = true)
 |-- price1: string (nullable = true)
 |-- price2: string (nullable = true)

Вы можете разыграть его, как показано ниже

scala> val df2 = df.withColumn("price1",'price1.cast(DecimalType(4,2)))
df2: org.apache.spark.sql.DataFrame = [item: string, price1: decimal(4,2) ... 1 more field]

scala> df2.printSchema
root
 |-- item: string (nullable = true)
 |-- price1: decimal(4,2) (nullable = true)
 |-- price2: string (nullable = true)


scala>

Теперь, если вы знаете список десятичных столбцов из CSV .. смассив, вы можете сделать это динамически, как показано ниже

scala> import org.apache.spark.sql.types._
import org.apache.spark.sql.types._

scala> val decimal_cols = Array("price1","price2")
decimal_cols: Array[String] = Array(price1, price2)

scala> val df3 = decimal_cols.foldLeft(df){ (acc,r) => acc.withColumn(r,col(r).cast(DecimalType(4,2))) }
df3: org.apache.spark.sql.DataFrame = [item: string, price1: decimal(4,2) ... 1 more field]

scala> df3.show
+----+------+------+
|item|price1|price2|
+----+------+------+
|   1| 79.07| 89.04|
|   2| 91.27|  1.02|
|   3| 85.60| 10.01|
+----+------+------+


scala> df3.printSchema
root
 |-- item: string (nullable = true)
 |-- price1: decimal(4,2) (nullable = true)
 |-- price2: decimal(4,2) (nullable = true)


scala>

Помогает ли это?.

UPDATE1:

Чтение файла CSV с использованием inferSchema изатем динамически приводим все двойные поля к десятичному типу (4,2).

val df = spark.read.format("csv").option("header","true").option("inferSchema","true").load("in/items.csv")
df.show
df.printSchema()
val decimal_cols = df.schema.filter( x=> x.dataType.toString == "DoubleType" ).map(x=>x.name)
// or df.schema.filter( x=> x.dataType==DoubleType )
val df3 = decimal_cols.foldLeft(df){ (acc,r) => acc.withColumn(r,col(r).cast(DecimalType(4,2))) }
df3.printSchema()
df3.show()

Результаты:

+-----+------+------+
|items|price1|price2|
+-----+------+------+
|    1| 79.07| 89.04|
|    2| 91.27|  1.02|
|    3|  85.6| 10.01|
+-----+------+------+

root
 |-- items: integer (nullable = true)
 |-- price1: double (nullable = true)
 |-- price2: double (nullable = true)

root
 |-- items: integer (nullable = true)
 |-- price1: decimal(4,2) (nullable = true)
 |-- price2: decimal(4,2) (nullable = true)

+-----+------+------+
|items|price1|price2|
+-----+------+------+
|    1| 79.07| 89.04|
|    2| 91.27|  1.02|
|    3| 85.60| 10.01|
+-----+------+------+
0 голосов
/ 07 декабря 2018

Можно указать конвертер для входной схемы:

def defineDecimalType(schema: StructType): StructType = {
    new StructType(
      schema.map {
        case StructField(name, dataType, nullable, metadata) =>
          if (dataType.isInstanceOf[DecimalType])
            // Pay attention to max precision in the source data
            StructField(name, new DecimalType(20, 2), nullable, metadata)
          else 
            StructField(name, dataType, nullable, metadata)
      }.toArray
    )
} 

def main(args: Array[String]): Unit = {
    val prices: Dataset = 
        readToDs("src/main/resources/app/data.csv", defineDecimalType(Encoders.product[ItemPrice].schema));
    prices.show();
}

Недостатком этого подхода является то, что это сопоставление применяется к каждому столбцу, и если у вас есть ID, который не подходитс точной точностью (скажем, от ID = 10000 до DecimalType(3, 2)) вы получите исключение:

Причина: java.lang.IllegalArgumentException: требование не выполнено: десятичная точность 4 превышает максимальную точность 3 при scala.Predef $ .require (Predef.scala: 224) в org.apache.spark.sql.types.Decimal.set (Decimal.scala: 113) в org.apache.spark.sql.types.Decimal $ .apply (десятичный.scala: 426) в org.apache.spark.sql.execution.datasources.csv.CSVTypeCast $ .castTo (CSVInferSchema.scala: 273) в org.apache.spark.sql.execution.datasources.csv.CSVRelation $$ anonfun$ csvParser $ 3.apply (CSVRelation.scala: 125) в org.apache.spark.sql.execution.datasources.csv.CSVRelation $$ anonfun $ csvParser $ 3.apply (CSVRelation.scala: 94) в org.apache.spark.sql.execution.datasources.csv.CSVFileFormat $$ anonfun $ buildReader $ 1 $$ anonfun $ применять $ 2.Apply (CSVFileFormat.scala: 167) в org.apache.spark.sql.execution.datasources.csv.CSVFileFormat $$ anonfun $ buildReader $ 1 $$ anonfun $ apply $ 2.apply (CSVFileFormat.scala: 166)

Так вот почему важно поддерживать точность выше, чем наибольший десятичный знак в исходных данных:

if (dataType.isInstanceOf[DecimalType])
    StructField(name, new DecimalType(20, 2), nullable, metadata)
0 голосов
/ 07 декабря 2018

Я попытался загрузить пример данных, используя 2 разных файла CSV, и он работает нормально, и результаты ожидаются для следующего кода.Я использую Spark 2.3.1 для Windows.

//read with double quotes
val df1 = spark.read
.format("csv")
.option("header","true")
.option("inferSchema","true")
.option("nullValue","")
.option("mode","failfast")
.option("path","D:/bitbuket/spark-examples/53667822/string.csv")
.load()

df1.show
/*
scala> df1.show
+---+-----+
| id|price|
+---+-----+
|  1|79.07|
|  2|91.27|
|  3| 85.6|
+---+-----+
*/

//read with without quotes
val df2 = spark.read
.format("csv")
.option("header","true")
.option("inferSchema","true")
.option("nullValue","")
.option("mode","failfast")
.option("path","D:/bitbuket/spark-examples/53667822/int-double.csv")
.load()

df2.show

/*
scala> df2.show
+---+-----+
| id|price|
+---+-----+
|  1|79.07|
|  2|91.27|
|  3| 85.6|
+---+-----+
*/

Result with your data set

...