Потратив некоторое время на определение работы PairGrid, я почти на месте.Ниже приведен код, который генерирует график, который я хочу, с одной маленькой деталью, отсутствующей в Histfunc.То, что я хочу, это название гистограммы, нанесенной на диагонали.Как передать имена столбцов в dataframe в Histfunc?Любые идеи приветствуются.
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import dcor
import random
from scipy.stats import linregress
from matplotlib import rc
font = {'family' : 'normal',
'weight' : 'normal',
'size' : 4}
rc('font', **font)
def corrmat(data):
def cm2inch(value):
"""helper function for plotting. Converts cm to inch"""
return value/2.54
def dist_corr(X, Y, pval=True, nruns=100):
""" Distance correlation with p-value from bootstrapping"""
dc = dcor.distance_correlation(X, Y)
pv = dcor.independence.distance_covariance_test(X, Y, exponent=1.0, num_resamples=nruns)[0]
if pval:
return (dc, pv)
else:
return dc
def linreg(X, Y, pval=True):
""" Linear regression"""
r2 = linregress(X,Y)[2]**2
pv = linregress(X,Y)[3]
if pval:
return (r2, pv)
else:
return r2
def scatterfunc(x, y, **kws):
# scatterplot with spline of deg=5 in red
plt.scatter(x, y, linewidths=1, facecolor="k", s=10, alpha = 0.5)
spline = np.polyfit(x, y, 5)
model = np.poly1d(spline)
x = np.sort(x)
plt.plot(x,model(x),'r-')
def histfunc(x, **kws):
# histogram
plt.hist(x,bins=30,color = "black", ec="white")
"""
vvvvvvvvvvvvvvvvvvvv
here something like
plt.title(label)
is missing but the **kws only contain label as string not as
parameter contaning the column name
^^^^^^^^^^^^^^^^^^^
"""
def corrfunc(x, y, dc=False, **kws):
# different sizes, text anc color in relation to r/d values
if dc:
d, p = dist_corr(x,y)
else:
d, p = linreg(x,y)
if d<0.25:
pclr = 'Black'
fontsize = 16
elif (d>=0.25) and (d<0.5):
pclr = 'Blue'
fontsize = 20
elif (d>=0.5) and (p<0.75):
pclr = 'Orange'
fontsize = 25
elif (p>0.75):
pclr = 'Red'
fontsize = 30
if p<0.001:
ptext = "***"
elif (p>=0.001) and (p<0.01):
ptext = "**"
elif (p>=0.01) and (p<0.05):
ptext = "*"
elif (p>0.05):
ptext = "n.sig"
ax = plt.gca()
if dc:
ax.annotate(''.join(['DC: ',str(np.round(d,2)),'\n\n ',ptext]),
xy=(0.3, 0.3),
xycoords=ax.transAxes,
color = pclr,
fontsize = fontsize)
else:
ax.annotate(''.join(['r2: ',str(np.round(d,2)),'\n\n ',ptext]),
xy=(0.3, 0.3),
xycoords=ax.transAxes,
color = pclr,
fontsize = fontsize)
plt.axis('off')
plt.figure(num=None, figsize=(cm2inch(15), cm2inch(10)), dpi=300, facecolor='w', edgecolor='k')
g = sns.PairGrid(data, diag_sharey=False)
g.map_upper(scatterfunc)
g.map_diag(histfunc)
g.map_lower(corrfunc)
plt.tight_layout()
plt.show()
########
data = pd.DataFrame(np.random.random([1000,10]),columns=[str(i) for i in range(10)])
for (i,col) in enumerate(data):
if i > 1:
if np.random.random()>0.5:
data[col]= data[col] * data.iloc[:,random.sample(set(np.arange(0,i)),1 )[0]]
corrmat(data)
, что он генерирует это