Названия гистограмм по диагонали при использовании seaborn.PairGrid в python для генерации корреляционной матрицы - PullRequest
1 голос
/ 26 июня 2019

Потратив некоторое время на определение работы PairGrid, я почти на месте.Ниже приведен код, который генерирует график, который я хочу, с одной маленькой деталью, отсутствующей в Histfunc.То, что я хочу, это название гистограммы, нанесенной на диагонали.Как передать имена столбцов в dataframe в Histfunc?Любые идеи приветствуются.

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import dcor
import random
from scipy.stats import linregress
from matplotlib import rc

font = {'family' : 'normal',
        'weight' : 'normal',
        'size'   : 4}
rc('font', **font)   

def corrmat(data):
    def cm2inch(value):
        """helper function for plotting. Converts cm to inch"""
        return value/2.54

    def dist_corr(X, Y, pval=True, nruns=100):
        """ Distance correlation with p-value from bootstrapping"""
        dc = dcor.distance_correlation(X, Y)
        pv = dcor.independence.distance_covariance_test(X, Y, exponent=1.0, num_resamples=nruns)[0]
        if pval:
            return (dc, pv)
        else:
            return dc    

    def linreg(X, Y, pval=True):
        """ Linear regression"""
        r2 = linregress(X,Y)[2]**2
        pv = linregress(X,Y)[3]
        if pval:
            return (r2, pv)
        else:
            return r2               

    def scatterfunc(x, y, **kws):
        # scatterplot with spline of deg=5 in red
        plt.scatter(x, y, linewidths=1, facecolor="k", s=10, alpha = 0.5)
        spline = np.polyfit(x, y, 5)
        model = np.poly1d(spline)
        x = np.sort(x)
        plt.plot(x,model(x),'r-')

    def histfunc(x, **kws):
        #  histogram
        plt.hist(x,bins=30,color = "black", ec="white")    
        """
        vvvvvvvvvvvvvvvvvvvv
        here something like 
        plt.title(label) 
        is missing but the **kws only contain label as string not as 
        parameter contaning the column name
        ^^^^^^^^^^^^^^^^^^^
        """

    def corrfunc(x, y, dc=False, **kws):  
        # different sizes, text anc color in relation to r/d values         
        if dc:
            d, p = dist_corr(x,y) 
        else:    
            d, p = linreg(x,y)

        if d<0.25:
            pclr = 'Black'
            fontsize = 16
        elif (d>=0.25) and (d<0.5):
            pclr = 'Blue'
            fontsize = 20
        elif (d>=0.5) and (p<0.75):
            pclr = 'Orange'
            fontsize = 25
        elif (p>0.75):
            pclr = 'Red'
            fontsize = 30

        if p<0.001:
            ptext = "***"
        elif (p>=0.001) and (p<0.01):
            ptext = "**"
        elif (p>=0.01) and (p<0.05):
            ptext = "*"
        elif (p>0.05):
            ptext = "n.sig"

        ax = plt.gca()
        if dc:
            ax.annotate(''.join(['DC: ',str(np.round(d,2)),'\n\n    ',ptext]),
                        xy=(0.3, 0.3), 
                        xycoords=ax.transAxes, 
                        color = pclr, 
                        fontsize = fontsize)
        else:
            ax.annotate(''.join(['r2: ',str(np.round(d,2)),'\n\n    ',ptext]),
                        xy=(0.3, 0.3), 
                        xycoords=ax.transAxes, 
                        color = pclr, 
                        fontsize = fontsize)

        plt.axis('off')


    plt.figure(num=None, figsize=(cm2inch(15), cm2inch(10)), dpi=300, facecolor='w', edgecolor='k')
    g = sns.PairGrid(data, diag_sharey=False)
    g.map_upper(scatterfunc)
    g.map_diag(histfunc)
    g.map_lower(corrfunc)
    plt.tight_layout()
    plt.show()


########

data = pd.DataFrame(np.random.random([1000,10]),columns=[str(i) for i in range(10)])   
for (i,col) in enumerate(data):
    if i > 1:
        if np.random.random()>0.5:
            data[col]= data[col] * data.iloc[:,random.sample(set(np.arange(0,i)),1 )[0]]
corrmat(data)

, что он генерирует это

enter image description here

1 Ответ

0 голосов
/ 27 июня 2019

Благодаря @ImportanceOfBeingErnest прокомментируйте здесь обновленный скрипт для тех, кто может найти его полезным.Я также переключил диаграмму рассеяния на «нижнюю», чтобы метки осей стали видны.

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import dcor
import random
from scipy.stats import linregress
from matplotlib import rc

font = {'family' : 'normal',
        'weight' : 'normal',
        'size'   : 16}
rc('font', **font)   

def corrmat(data):
    def cm2inch(value):
        """helper function for plotting. Converts cm to inch"""
        return value/2.54

    def dist_corr(X, Y, pval=True, nruns=100):
        """ Distance correlation with p-value from bootstrapping"""
        dc = dcor.distance_correlation(X, Y)
        pv = dcor.independence.distance_covariance_test(X, Y, exponent=1.0, num_resamples=nruns)[0]
        if pval:
            return (dc, pv)
        else:
            return dc    

    def linreg(X, Y, pval=True):
        """ Linear regression"""
        r2 = linregress(X,Y)[2]**2
        pv = linregress(X,Y)[3]
        if pval:
            return (r2, pv)
        else:
            return r2               

    def scatterfunc(x, y, **kws):
        """ scatterplot with spline of deg=5 in red"""
        plt.scatter(x, y, linewidths=1, facecolor="k", s=10, alpha = 0.5)
        spline = np.polyfit(x, y, 5)
        model = np.poly1d(spline)
        x = np.sort(x)
        plt.plot(x,model(x),'r-')

    def histfunc(x, **kws):
        """ histogram"""
        plt.hist(x,bins=30,color = "black", ec="white")    

    def corrfunc(x, y, dc=False, **kws):  
        """different sizes, text anc color in relation to r/d values
           the dc parameter determines wheter distance correlation or 
           linear regression should be applied"""
        if dc:
            d, p = dist_corr(x,y) 
        else:    
            d, p = linreg(x,y)

        if d<0.25:
            pclr = 'Black'
            fontsize = 16
        elif (d>=0.25) and (d<0.5):
            pclr = 'Blue'
            fontsize = 20
        elif (d>=0.5) and (p<0.75):
            pclr = 'Orange'
            fontsize = 25
        elif (p>0.75):
            pclr = 'Red'
            fontsize = 30

        if p<0.001:
            ptext = "***"
        elif (p>=0.001) and (p<0.01):
            ptext = "**"
        elif (p>=0.01) and (p<0.05):
            ptext = "*"
        elif (p>0.05):
            ptext = "n.sig"

        ax = plt.gca()
        if dc:
            ax.annotate(''.join(['DC: ',str(np.round(d,2)),'\n\n    ',ptext]),
                        xy=(0.3, 0.3), 
                        xycoords=ax.transAxes, 
                        color = pclr, 
                        fontsize = fontsize)
        else:
            ax.annotate(''.join(['r2: ',str(np.round(d,2)),'\n\n    ',ptext]),
                        xy=(0.3, 0.3), 
                        xycoords=ax.transAxes, 
                        color = pclr, 
                        fontsize = fontsize)

        plt.axis('off')

    def make_diag_titles(g,titles):
        for (i,row) in enumerate(g.axes):
            g.axes[i][i].title.set_text(titles[i])
        return g
    ###
    # here the plot is put together
    plt.figure(num=None, figsize=(cm2inch(15), cm2inch(10)), dpi=300, facecolor='w', edgecolor='k')
    g = sns.PairGrid(data, diag_sharey=False)
    g.map_lower(scatterfunc)
    g.map_diag(histfunc)
    g.map_upper(corrfunc)
    g = make_diag_titles(g, data.columns)
    plt.tight_layout()
    plt.show()


########

data = pd.DataFrame(np.random.random([1000,10]),columns=[str(i) for i in range(10)])   
for (i,col) in enumerate(data):
    if i > 1:
        if np.random.random()>0.5:
            data[col]= data[col] * data.iloc[:,random.sample(set(np.arange(0,i)),1 )[0]]
corrmat(data)

enter image description here

...