Отрегулируйте участок с тепловыми картами - PullRequest
0 голосов
/ 03 января 2019

У меня есть Subplots с Матрицами смешения, которые представлены с HeatMap.

Subplot with Heatmaps Я хотел бы настроить график так, чтобы он был более читабельным и выполнял такие вещи, как:

1) Добавьте один большой заголовок над столбцами «Цели»

2) Добавьте один большой Ylabel «Предсказания»

3) для каждого столбца есть только одна большая легенда,так как они показывают одно и то же

4), для каждого столбца добавьте имена столбцов ['Train CM', 'Train Norm CM', 'Validation CM', 'Validation Norm CM'] и имена строк [f'Epoch {i}' for i in range(n_epoch)].Я сделал, как в здесь , но работает только для столбцов, а не для строк, я не знаю почему.

Мой код:

cols = ['Train CM', 'Train Norm CM', 'Validation CM', 'Validation Norm CM']
rows = [f'Epoch {i}' for i in range(n_epoch)]

f, axes  = plt.subplots(nrows = n_epoch, ncols = 4, figsize=(40, 30))
for ax, col in zip(axes [0], cols):
    ax.set_title(col, size='large')

for ax, row in zip(axes[:,0], rows):
    ax.set_ylabel(row, rotation=0, size='large')

f.tight_layout()

for e in range(n_epoch):
    for c in range(4):
        # take conf matrix from lists cm_Train or cm_Validation of ConfusionMatrix() objects
        if c == 0:
            cm = np.transpose(np.array([list(item.values()) for item in cm_Train[e].matrix.values()]))
        elif c == 1:
            cm = np.transpose(np.array([list(item.values()) for item in cm_Train[e].normalized_matrix.values()]))
        elif c == 2:
        cm = np.transpose(np.array([list(item.values()) for item in cm_Validation[e].matrix.values()]))
    else:
        cm = np.transpose(np.array([list(item.values()) for item in cm_Validation[e].normalized_matrix.values()]))
    sns.heatmap(cm, annot=True, fmt='g', ax = axes[e, c], linewidths=.3)

1 Ответ

0 голосов
/ 03 января 2019

Я представляю решение с пустыми участками, потому что у меня нет ваших данных.Это то, что вы хотите:

n_epoch = 4
cols = ['Train CM', 'Train Norm CM', 'Validation CM', 'Validation Norm CM']
rows = [f'Epoch {i}' for i in range(n_epoch)]

f, axes  = plt.subplots(nrows = n_epoch, ncols = 4, figsize=(12, 8))

f.text(0, 0.5, 'Predictions', ha='center', va='center', fontsize=20, rotation='vertical')
plt.suptitle("One big title", fontsize=18, y=1.05)

for ax, col in zip(axes [0], cols):
    ax.set_title(col, size='large')

for ax, row in zip(axes[:, 0], rows):
    ax.set_ylabel(row, size='large')

plt.tight_layout()    

enter image description here

Размещение цветных полос : Здесь вы помещаете цветные полосы, охватывающие все строки длякаждый столбец.Однако здесь tight_layout() несовместим, поэтому вам придется отключить его.

f, axes  = plt.subplots(nrows = n_epoch, ncols = 4, figsize=(12, 8))

for i, ax in enumerate(axes.flat):
    im = ax.imshow(np.random.random((20,20)), vmin=0, vmax=1)
    if i%4 == 0:
        f.colorbar(im, ax=axes[:,int(i/4)].ravel().tolist(), aspect=30, pad=0.05)    

f.text(0.08, 0.5, 'Predictions', ha='center', va='center', fontsize=20, rotation='vertical')
plt.suptitle("One big title", fontsize=18)

for ax, col in zip(axes [0], cols):
    ax.set_title(col, size='large')

for ax, row in zip(axes[:, 0], rows):
    ax.set_ylabel(row, size='large')

enter image description here

...