Отобразить сложенную гистограмму, циклически перебирая кластеры внутри df - PullRequest
0 голосов
/ 19 апреля 2019

У меня есть набор данных, который включает в себя все средние значения бейсбольных игроков. Я назначаю каждого игрока в этом наборе данных случайным образом в кластер. Теперь я хочу визуально отобразить каждый кластер в виде гистограммы. Я использую следующее:

import matplotlib.pyplot as plt

def chart(k=2):
    x = np.arange(0, 0.4, 0.001)
    for j in range(k):
        cluster = df.loc[df['cluster'] == j].reset_index()
        plt.hist(cluster['Average'], bins=50, density=1, stacked=True)
    plt.xlim(0, 0.4)
    plt.xlabel('Batting Average')
    plt.ylabel('Density')
    plt.show()

Это дает мне следующий вывод: enter image description here

Однако я бы хотел увидеть следующее:

enter image description here

Я создал эту диаграмму, разделив набор данных «жестко закодировано». В идеале я хочу сделать это динамически, создав цикл. Как я могу также добавить легенду с именами кластеров и указать цвет для каждого кластера? Опять все в петле. K также может быть 10, например. Заранее спасибо

1 Ответ

1 голос
/ 19 апреля 2019

Не предоставляя данные и Минимальный, полный и проверяемый пример людям, прежде чем задать вопрос, затрудняет решение вашей проблемы.Это то, что вы должны иметь в виду в следующий раз.Тем не менее, вот один способ, который должен работать для вас.Идея состоит в том, чтобы создать объект оси ax и передать его для построения обеих гистограмм на одном и том же рисунке.Затем вы можете изменить метки, ограничения и т. Д. Вне функции после построения всего графика.

PS: Как отметил Пол Х в комментариях ниже, DataFrame df и имена столбцов должны также передаваться в качестве аргументов функции диаграммы, чтобы сделать ее более устойчивой

import matplotlib.pyplot as plt

def chart(ax1, k=2):
    x = np.arange(0, 0.4, 0.001)
    for j in range(k):
        cluster = df.loc[df['cluster'] == j].reset_index()
        ax1.hist(cluster['Average'], bins=50, density=1, stacked=True)
    return ax1

fig, ax = plt.subplots()

ax = chart(ax, k=2)    
plt.xlim(0, 0.4)
plt.xlabel('Batting Average')
plt.ylabel('Density')
plt.show()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...