Как построить корреляционную тепловую карту при использовании pyspark + databricks - PullRequest
1 голос
/ 06 апреля 2019

Я изучаю pyspark в кирпичах данных. Я хочу создать тепловую карту корреляции. Допустим, это мои данные:

myGraph=spark.createDataFrame([(1.3,2.1,3.0),
                               (2.5,4.6,3.1),
                               (6.5,7.2,10.0)],
                              ['col1','col2','col3'])

А это мой код:

import pyspark
from pyspark.sql import SparkSession
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from ggplot import *
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation
from pyspark.mllib.stat import Statistics

myGraph=spark.createDataFrame([(1.3,2.1,3.0),
                               (2.5,4.6,3.1),
                               (6.5,7.2,10.0)],
                              ['col1','col2','col3'])
vector_col = "corr_features"
assembler = VectorAssembler(inputCols=['col1','col2','col3'], 
                            outputCol=vector_col)
myGraph_vector = assembler.transform(myGraph).select(vector_col)
matrix = Correlation.corr(myGraph_vector, vector_col)
matrix.collect()[0]["pearson({})".format(vector_col)].values

Пока я не смогу получить матрицу корреляции. Результат выглядит так:

enter image description here

Теперь мои проблемы:

  1. Как перенести матрицу в кадр данных? Я попробовал методы Как преобразовать DenseMatrix для запуска DataFrame в pyspark? и Как получить значения матрицы корреляции pyspark . Но это не работает для меня.
  2. Как создать корреляционную тепловую карту, которая выглядит следующим образом:

enter image description here

Потому что я только что изучил pyspark и databricks. Поэтому, пожалуйста, дайте мне как можно больше подробностей. ggplot или matplotlib подходят для моей проблемы.

1 Ответ

2 голосов
/ 07 апреля 2019

Я думаю, что точка, в которой вы запутались, это:

matrix.collect()[0]["pearson({})".format(vector_col)].values

Вызов .values ​​плотной матрицы дает вам список всех значений, но на самом деле вы ищете список, представляющий матрицу корреляции.

import matplotlib.pyplot as plt
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation

columns = ['col1','col2','col3']

myGraph=spark.createDataFrame([(1.3,2.1,3.0),
                               (2.5,4.6,3.1),
                               (6.5,7.2,10.0)],
                              columns)
vector_col = "corr_features"
assembler = VectorAssembler(inputCols=['col1','col2','col3'], 
                            outputCol=vector_col)
myGraph_vector = assembler.transform(myGraph).select(vector_col)
matrix = Correlation.corr(myGraph_vector, vector_col)

До сих пор это был в основном ваш код. Вместо вызова .values ​​вы должны использовать .toArray (). Tolist (), чтобы получить список списков, представляющих матрицу корреляции:

matrix = Correlation.corr(myGraph_vector, vector_col).collect()[0][0]
corrmatrix = matrix.toArray().tolist()
print(corrmatrix)

Выход:

[[1.0, 0.9582184104641529, 0.9780872729407004], [0.9582184104641529, 1.0, 0.8776695567739841], [0.9780872729407004, 0.8776695567739841, 1.0]]

Преимущество этого подхода заключается в том, что вы можете легко превратить список списков в фрейм данных:

df = spark.createDataFrame(corrmatrix,columns)
df.show()

Выход:

+------------------+------------------+------------------+ 
|              col1|              col2|              col3| 
+------------------+------------------+------------------+ 
|               1.0|0.9582184104641529|0.9780872729407004|
|0.9582184104641529|               1.0|0.8776695567739841| 
|0.9780872729407004|0.8776695567739841|               1.0|  
+------------------+------------------+------------------+

Чтобы ответить на ваш второй вопрос. Просто одно из многих решений для построения тепловой карты (например, , или , , еще лучше с seaborn ).

def plot_corr_matrix(correlations,attr,fig_no):
    fig=plt.figure(fig_no)
    ax=fig.add_subplot(111)
    ax.set_title("Correlation Matrix for Specified Attributes")
    ax.set_xticklabels(['']+attr)
    ax.set_yticklabels(['']+attr)
    cax=ax.matshow(correlations,vmax=1,vmin=-1)
    fig.colorbar(cax)
    plt.show()

plot_corr_matrix(corrmatrix, columns, 234)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...