Цветная полоса рядом с парным графиком - PullRequest
2 голосов
/ 28 мая 2020
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
from pandas import DataFrame

# generate 2d classification dataset
X, y = make_blobs(n_samples=100, centers=3, n_features=2)
# scatter plot, dots colored by class value
df1 = DataFrame(dict(x=X[:,0], y=X[:,1], label=y))


norm = plt.Normalize(df1.x.min(), df1.x.max())
sm = plt.cm.ScalarMappable(cmap=sns.cubehelix_palette(df1['x'].max(), start=.5, rot=-.75,as_cmap=True), norm=norm)
sm.set_array([])

ax=sns.pairplot(df1,vars=['x','y'], 
            hue='label',
            palette=sns.cubehelix_palette(df1['x'].max(), start=.5, rot=-.75),
            diag_kind=None,plot_kws={"s": 50})

ax._legend.remove()

# ax.set_ylabel('WT04: Pairplot for features')
cbar=ax.fig.colorbar(sm)
m0=int(np.floor(df1.x.min()))            # colorbar min value
m4=int(np.ceil(df1.x.max()))             # colorbar max value
m1=int(1*(m4-m0)/4.0 + m0)               # colorbar mid value 1
m2=int(2*(m4-m0)/4.0 + m0)               # colorbar mid value 2
m3=int(3*(m4-m0)/4.0 + m0)               # colorbar mid value 3
cbar.set_ticks([m0,m1,m2,m3,m4])
cbar.set_ticklabels([m0,m1,m2,m3,m4])

ax.fig.suptitle('WT04: Pairplot for features',y=0.99)
plt.subplots_adjust(top=0.94)
plt.show()

plt.savefig('result.png')

Создает этот график введите описание изображения здесь

Однако я бы хотел, чтобы шкала цветов была справа от парного графика, а не внизу справа. Как это может быть сделано? Некоторые из моих парных графиков имеют другие размеры, поэтому было бы неплохо что-нибудь масштабируемое.

1 Ответ

2 голосов
/ 28 мая 2020

Вы можете попробовать добавить цветовую панель с помощью plt:

ax._legend.remove()

# remove this colorbar
# cbar=ax.fig.colorbar(sm)

m0=int(np.floor(df1.x.min()))            # colorbar min value
m4=int(np.ceil(df1.x.max()))             # colorbar max value
m1=int(1*(m4-m0)/4.0 + m0)               # colorbar mid value 1
m2=int(2*(m4-m0)/4.0 + m0)               # colorbar mid value 2
m3=int(3*(m4-m0)/4.0 + m0)               # colorbar mid value 3

ax.fig.suptitle('WT04: Pairplot for features',y=0.99)

# use this colorbar
cbar = plt.colorbar(sm, ax=ax.axes)

cbar.set_ticks([m0,m1,m2,m3,m4])
cbar.set_ticklabels([m0,m1,m2,m3,m4])

Вывод:

enter image description here

...