Спасибо, что поделились своим кодом!Вы выяснили все сложные вещи для нас.Работая с ним, я заметил несколько мелочей, которые выглядели не совсем правильно.
[FIX # 1] Тики оси не выстраивались так, как я ожидал(т. е. в приведенном выше примере вы должны быть в состоянии нарисовать вертикальную и горизонтальную линию через любую точку на всех графиках, и линии должны пересекать соответствующую точку на других графиках, но, поскольку она установлена сейчас, этого не происходит.
[FIX # 2] Если у вас есть нечетное количество переменных, с которыми вы строите график, то ось нижнего правого угла не вытягивает правильный xtics или ytic. Он просто оставляетотметки по умолчанию 0..1.
Не исправлено, но я сделал необязательным явный ввод names
, так что в переменной i устанавливается значение по умолчанию xi
диагональные позиции.
Ниже вы найдете обновленную версию своего кода, в которой рассматриваются эти две точки, в противном случае сохраняется красота вашего кода.
import itertools
import numpy as np
import matplotlib.pyplot as plt
def scatterplot_matrix(data, names=[], **kwargs):
"""
Plots a scatterplot matrix of subplots. Each row of "data" is plotted
against other rows, resulting in a nrows by nrows grid of subplots with the
diagonal subplots labeled with "names". Additional keyword arguments are
passed on to matplotlib's "plot" command. Returns the matplotlib figure
object containg the subplot grid.
"""
numvars, numdata = data.shape
fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
fig.subplots_adjust(hspace=0.0, wspace=0.0)
for ax in axes.flat:
# Hide all ticks and labels
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
# Set up ticks only on one side for the "edge" subplots...
if ax.is_first_col():
ax.yaxis.set_ticks_position('left')
if ax.is_last_col():
ax.yaxis.set_ticks_position('right')
if ax.is_first_row():
ax.xaxis.set_ticks_position('top')
if ax.is_last_row():
ax.xaxis.set_ticks_position('bottom')
# Plot the data.
for i, j in zip(*np.triu_indices_from(axes, k=1)):
for x, y in [(i,j), (j,i)]:
# FIX #1: this needed to be changed from ...(data[x], data[y],...)
axes[x,y].plot(data[y], data[x], **kwargs)
# Label the diagonal subplots...
if not names:
names = ['x'+str(i) for i in range(numvars)]
for i, label in enumerate(names):
axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
ha='center', va='center')
# Turn on the proper x or y axes ticks.
for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
axes[j,i].xaxis.set_visible(True)
axes[i,j].yaxis.set_visible(True)
# FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
# correct axes limits, so we pull them from other axes
if numvars%2:
xlimits = axes[0,-1].get_xlim()
ylimits = axes[-1,0].get_ylim()
axes[-1,-1].set_xlim(xlimits)
axes[-1,-1].set_ylim(ylimits)
return fig
if __name__=='__main__':
np.random.seed(1977)
numvars, numdata = 4, 10
data = 10 * np.random.random((numvars, numdata))
fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
linestyle='none', marker='o', color='black', mfc='none')
fig.suptitle('Simple Scatterplot Matrix')
plt.show()
Спасибоснова, чтобы поделиться этим с нами. У меня естьиспользовал это много раз!О, и я перестроил часть кода main()
так, чтобы он мог быть формальным примером кода или не вызываться, если он импортируется в другой фрагмент кода.