Я думаю, что точка, в которой вы запутались, это:
matrix.collect()[0]["pearson({})".format(vector_col)].values
Вызов .values плотной матрицы дает вам список всех значений, но на самом деле вы ищете список, представляющий матрицу корреляции.
import matplotlib.pyplot as plt
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation
columns = ['col1','col2','col3']
myGraph=spark.createDataFrame([(1.3,2.1,3.0),
(2.5,4.6,3.1),
(6.5,7.2,10.0)],
columns)
vector_col = "corr_features"
assembler = VectorAssembler(inputCols=['col1','col2','col3'],
outputCol=vector_col)
myGraph_vector = assembler.transform(myGraph).select(vector_col)
matrix = Correlation.corr(myGraph_vector, vector_col)
До сих пор это был в основном ваш код. Вместо вызова .values вы должны использовать .toArray (). Tolist (), чтобы получить список списков, представляющих матрицу корреляции:
matrix = Correlation.corr(myGraph_vector, vector_col).collect()[0][0]
corrmatrix = matrix.toArray().tolist()
print(corrmatrix)
Выход:
[[1.0, 0.9582184104641529, 0.9780872729407004], [0.9582184104641529, 1.0, 0.8776695567739841], [0.9780872729407004, 0.8776695567739841, 1.0]]
Преимущество этого подхода заключается в том, что вы можете легко превратить список списков в фрейм данных:
df = spark.createDataFrame(corrmatrix,columns)
df.show()
Выход:
+------------------+------------------+------------------+
| col1| col2| col3|
+------------------+------------------+------------------+
| 1.0|0.9582184104641529|0.9780872729407004|
|0.9582184104641529| 1.0|0.8776695567739841|
|0.9780872729407004|0.8776695567739841| 1.0|
+------------------+------------------+------------------+
Чтобы ответить на ваш второй вопрос. Просто одно из многих решений для построения тепловой карты (например, , или , , еще лучше с seaborn ).
def plot_corr_matrix(correlations,attr,fig_no):
fig=plt.figure(fig_no)
ax=fig.add_subplot(111)
ax.set_title("Correlation Matrix for Specified Attributes")
ax.set_xticklabels(['']+attr)
ax.set_yticklabels(['']+attr)
cax=ax.matshow(correlations,vmax=1,vmin=-1)
fig.colorbar(cax)
plt.show()
plot_corr_matrix(corrmatrix, columns, 234)