Предполагая, что вы получите свой CSV как
scala> val df = Seq(("1","79.07","89.04"),("2","91.27","1.02"),("3","85.6","10.01")).toDF("item","price1","price2")
df: org.apache.spark.sql.DataFrame = [item: string, price1: string ... 1 more field]
scala> df.printSchema
root
|-- item: string (nullable = true)
|-- price1: string (nullable = true)
|-- price2: string (nullable = true)
Вы можете разыграть его, как показано ниже
scala> val df2 = df.withColumn("price1",'price1.cast(DecimalType(4,2)))
df2: org.apache.spark.sql.DataFrame = [item: string, price1: decimal(4,2) ... 1 more field]
scala> df2.printSchema
root
|-- item: string (nullable = true)
|-- price1: decimal(4,2) (nullable = true)
|-- price2: string (nullable = true)
scala>
Теперь, если вы знаете список десятичных столбцов из CSV .. смассив, вы можете сделать это динамически, как показано ниже
scala> import org.apache.spark.sql.types._
import org.apache.spark.sql.types._
scala> val decimal_cols = Array("price1","price2")
decimal_cols: Array[String] = Array(price1, price2)
scala> val df3 = decimal_cols.foldLeft(df){ (acc,r) => acc.withColumn(r,col(r).cast(DecimalType(4,2))) }
df3: org.apache.spark.sql.DataFrame = [item: string, price1: decimal(4,2) ... 1 more field]
scala> df3.show
+----+------+------+
|item|price1|price2|
+----+------+------+
| 1| 79.07| 89.04|
| 2| 91.27| 1.02|
| 3| 85.60| 10.01|
+----+------+------+
scala> df3.printSchema
root
|-- item: string (nullable = true)
|-- price1: decimal(4,2) (nullable = true)
|-- price2: decimal(4,2) (nullable = true)
scala>
Помогает ли это?.
UPDATE1:
Чтение файла CSV с использованием inferSchema изатем динамически приводим все двойные поля к десятичному типу (4,2).
val df = spark.read.format("csv").option("header","true").option("inferSchema","true").load("in/items.csv")
df.show
df.printSchema()
val decimal_cols = df.schema.filter( x=> x.dataType.toString == "DoubleType" ).map(x=>x.name)
// or df.schema.filter( x=> x.dataType==DoubleType )
val df3 = decimal_cols.foldLeft(df){ (acc,r) => acc.withColumn(r,col(r).cast(DecimalType(4,2))) }
df3.printSchema()
df3.show()
Результаты:
+-----+------+------+
|items|price1|price2|
+-----+------+------+
| 1| 79.07| 89.04|
| 2| 91.27| 1.02|
| 3| 85.6| 10.01|
+-----+------+------+
root
|-- items: integer (nullable = true)
|-- price1: double (nullable = true)
|-- price2: double (nullable = true)
root
|-- items: integer (nullable = true)
|-- price1: decimal(4,2) (nullable = true)
|-- price2: decimal(4,2) (nullable = true)
+-----+------+------+
|items|price1|price2|
+-----+------+------+
| 1| 79.07| 89.04|
| 2| 91.27| 1.02|
| 3| 85.60| 10.01|
+-----+------+------+