Как можно объединить несколько столбцов с плавающей точкой в ​​один ArrayType (FloatType ()) в DataFrame Spark? - PullRequest
0 голосов
/ 26 февраля 2019

У меня есть искра DataFrame со многими столбцами с плавающей точкой после чтения в файл CSV.

Я хочу объединить все столбцы с плавающей точкой в ​​один ArrayType(FloatType()).

Есть идеи, как это сделать с помощью PySpark (или Scala)?

Ответы [ 3 ]

0 голосов
/ 26 февраля 2019

Нашел решение.Очень просто, но трудно найти.

float_cols = ['_c1', '_c2', '_c3', '_c4', '_c5', '_c6', '_c7', '_c8', '_c9', '_c10']

df.withColumn('combined', array([col(c) for c in float_cols]))
0 голосов
/ 26 февраля 2019

Вот еще одна версия в Scala:

data.printSchema

root
 |-- Int_Col1: integer (nullable = false)
 |-- Str_Col1: string (nullable = true)
 |-- Float_Col1: float (nullable = false)
 |-- Float_Col2: float (nullable = false)
 |-- Str_Col2: string (nullable = true)
 |-- Float_Col3: float (nullable = false)

data.show()

+--------+--------+----------+----------+--------+----------+
|Int_Col1|Str_Col1|Float_Col1|Float_Col2|Str_Col2|Float_Col3|
+--------+--------+----------+----------+--------+----------+
|       1|     ABC|     10.99|     20.99|       a|      9.99|
|       2|     XYZ|  999.1343|    9858.1|       b|    488.99|
+--------+--------+----------+----------+--------+----------+

Добавить новое поле array<float>, чтобы объединить все значения float.

val df = data.withColumn("Float_Arr_Col",array().cast("array<float>"))

Затем отфильтруйте необходимый тип данных и объедините столбцы с плавающей запятой, используя foldLeft

df.dtypes
.collect{ case (dn, dt) if dt.startsWith("FloatType") => dn }
.foldLeft(df)((accDF, c) => accDF.withColumn("Float_Arr_Col", 
                                             array_union(col("Float_Arr_Col"),array(col(c)))))
.show(false)

Вывод:

+--------+--------+----------+----------+--------+----------+--------------------------+
|Int_Col1|Str_Col1|Float_Col1|Float_Col2|Str_Col2|Float_Col3|Float_Arr_Col             |
+--------+--------+----------+----------+--------+----------+--------------------------+
|1       |ABC     |10.99     |20.99     |a       |9.99      |[10.99, 20.99, 9.99]      |
|2       |XYZ     |999.1343  |9858.1    |b       |488.99    |[999.1343, 9858.1, 488.99]|
+--------+--------+----------+----------+--------+----------+--------------------------+

Надеюсь, это поможет!

0 голосов
/ 26 февраля 2019

Если вы знаете все имя столбца с плавающей точкой.Вы можете попробовать это (scala)

val names = Seq("float_col1", "float_col2","float_col3"...."float_col10");
df.withColumn("combined", array(names.map(frame(_)):_*))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...