Как добавить строки сравнения в lmplot в seaborn? - PullRequest
1 голос
/ 09 мая 2020

Я хотел бы объединить следующие lmplots, которые у меня есть. В частности, красные линии представляют собой средние значения для каждого сезона, и я хочу разместить их на соответствующих lmplots с другими данными, вместо того, чтобы разделять их. Вот мой код (обратите внимание, ограничения по осям не работают, потому что второй lmplot портит его. Он работает, когда я просто рисую начальные данные):

ax = sns.lmplot(data=data, x='air_yards', y='cpoe',col='season', lowess = True, scatter_kws={'alpha':.6, 'color': '#4F2E84'}, line_kws={'alpha':.6, 'color': '#4F2E84'})

ax = sns.lmplot(data=avg, x='air_yards', y= 'cpoe',lowess=True, scatter=False, line_kws={'linestyle':'--', 'color': 'red'}, col = 'season')

axes.set_xlim([-5,30])
axes.set_ylim([-25,25])

ax.set(xlabel='air yards')

И вот результат. Проще говоря, я хочу взять эти красные линии и поместить их на соответствующие годовые графики выше. Спасибо! And here is the output

1 Ответ

2 голосов
/ 09 мая 2020

Не уверен, возможно ли это так, как вы хотите, поэтому может быть что-то вроде:

import matplotlib.pyplot as plt
import seaborn as sns

#dummy example
data = pd.DataFrame({'air_yards': range(1,11), 
                     'cpoe': range(1,11), 
                     'season': [1,2,3,2,1,3,2,1,3,2]})
avg = pd.DataFrame({'air_yards': [1, 10]*3, 
                    'cpoe': [2,2,5,5,8,8], 
                    'season': [1,1,2,2,3,3]})

# need this info
n = data["season"].nunique()

# create the number of subplots
fig, axes = plt.subplots(ncols=n, sharex=True, sharey=True)

# now you need to loop through unique season
for ax, (season, dfg) in zip(axes.flat, data.groupby("season")):
    # set title
    ax.set_title(f'season={season}')

    # create the replot for data
    sns.regplot("air_yards", "cpoe", data=dfg, ax=ax, 
                lowess = True, scatter_kws={'alpha':.6, 'color': '#4F2E84'}, 
                line_kws={'alpha':.6, 'color': '#4F2E84'})

    # create regplot for avg
    sns.regplot("air_yards", "cpoe", data=avg[avg['season'].eq(season)], ax=ax, 
                lowess=True, scatter=False, 
                line_kws={'linestyle':'--', 'color': 'red'})

plt.show()

вы получите enter image description here

...