использовать одно и то же цветовое сопоставление для строк таблицы и мультидистплот - PullRequest
1 голос
/ 09 июля 2020

У меня есть фрейм данных, из которого я беру несколько групп данных и отображаю их как отображение на том же рисунке (с наложением). Я также показываю таблицу, в которой обобщены некоторые данные по каждой группе. Я хотел бы отображать каждую строку в таблице (= каждую группу) тем же цветом, что и соответствующий цвет диспетчеризации. Я попытался определить общую цветовую карту как для таблицы, так и для диспетчера, однако диспетчер выдает ошибку:

in distplot
    if kde_color != color:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Process finished with exit code 1

Вот код:


fig, (ax_plot, ax_table) = plt.subplots(nrows=2, figsize=(11.69, 8.27),
                                            gridspec_kw=dict(height_ratios=[3, 1]) )    
ax_table.axis("off")

item_types = item_df['item_type'].unique()
columns = ('item type', 'Average DR', 'Percent DR passed 50%', 'Percent DR passed 60%', 'Percent DR passed 70%',
           'Percent DR passed 80%')
cell_text = []
table_colors = plt.cm.BuPu(np.linspace(0, 0.5, len(item_types)))
i=0
    
for item_type in item_types:
    item_dr = item_df[item_df['item_type'] == item_type]['interesting_feature'].values
    color = table_colors[i, 0:3]
    sns.distplot(item_dr, hist=False, label=item_type, ax=ax_plot, color=mcolors.rgb_to_hsv(color))
    i += 1
    avg_dr = np.mean(item_dr)
    pass50 = len(item_dr[item_dr > 0.5]) / len(item_dr)
    pass60 = len(item_dr[item_dr > 0.6]) / len(item_dr)
    pass70 = len(item_dr[item_dr > 0.7]) / len(item_dr)
    pass80 = len(item_dr[item_dr > 0.8]) / len(item_dr)

    cell_text.append([str(item_type), str(avg_dr), str(pass50), str(pass60), str(pass70), str(pass80)])
item_table = ax_table.table(cellText=cell_text,
                            colLabels=columns,
                            loc='center',
                            fontsize=20,
                            rowColours=table_colors)

1 Ответ

0 голосов
/ 10 июля 2020

Во-первых, преобразование в hsv, как в mcolors.rgb_to_hsv(color), не выглядит очень полезным.

Теперь основная проблема, по-видимому, заключается в передаче цвета в виде списка или массива numpy ([1, 0, 0]) смущает sns.distplot(..., color=color). Многие функции seaborn допускают использование одного цвета или списка цветов и не различают guish цвет, передаваемый как значения RGB, и массив. Обходной путь - преобразовать список в кортеж: sns.distplot(..., color=tuple(color)).

Вот минимальный пример:

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

num_colors = 5
table_colors = plt.cm.BuPu(np.linspace(0, 0.5, num_colors))

fig, (ax_plot, ax_table) = plt.subplots(nrows=2)
for i in range(num_colors):
    color = table_colors[i, 0:3]
    # sns.distplot(np.random.normal(0, 1, 100), hist=False, color=color) # gives an error
    sns.distplot(np.random.normal(0, 1, 100), hist=False, color=tuple(color), ax=ax_plot)

columns = list('abcdef')
num_columns = len(columns)
ax_table.table(cellText=np.random.randint(1, 1000, size=(num_colors, num_columns)) / 100,
               colLabels=columns, loc='center', fontsize=20,
               cellColours=np.repeat(table_colors, num_columns, axis=0).reshape(num_colors, num_columns, -1))
ax_table.axis('off')
plt.tight_layout()
plt.show()

example plot

To change the color of the text, you can loop through the cells of the table. As these particular colors are not very visible on a white background, the cell background could be set to black.

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

num_colors = 5
table_colors = plt.cm.BuPu(np.linspace(0, 0.5, num_colors))

fig, (ax_plot, ax_table) = plt.subplots(nrows=2)
for i in range(num_colors):
    color = table_colors[i, :]
    # sns.distplot(np.random.normal(0, 1, 100), hist=False, color=color) # gives an error
    sns.distplot(np.random.normal(0, 1, 100), hist=False, color=tuple(color), ax=ax_plot)

columns = list('abcdef')
num_columns = len(columns)
table = ax_table.table(cellText=np.random.randint(1, 1000, size=(num_colors, num_columns)) / 100,
                       colLabels=columns, loc='center', fontsize=20)
for i in range(num_colors):
    for j in range(num_columns):
        table[(i+1, j)].set_color('black')   # +1: skip the table header
        table[(i+1, j)].get_text().set_color(table_colors[i, :])
ax_table.axis('off')
plt.tight_layout()
plt.show()

изменение цвета текста

...