matplotlib: одна и та же легенда для двух наборов данных - PullRequest
0 голосов
/ 18 апреля 2020

Я строю два набора данных в фреймах данных, используя matplotlib. Наборы данных представлены разными стилями линий. Ниже приведен код.

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
df1 = pd.DataFrame(np.random.randn(10, 16))
df2 = pd.DataFrame(np.random.randn(10, 16))


plt.figure()
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

df1.plot(ax=axes[0], style='-', legend=True)
axes[0].set_xlabel('x')
axes[0].set_ylabel('y')
axes[0].set_title('ttl')

df2.plot(ax=axes[0], style='--', legend=True)
axes[0].set_xlabel('x')
axes[0].set_ylabel('y')
axes[0].set_title('ttl')

plt.show()

enter image description here

Однако последовательность цветов различна для разных стилей линий. например, 0 в line и 0 в dashed line имеют разные цвета. Я хотел бы попросить совета, как получить одинаковую цветовую последовательность для обоих стилей линий.

РЕДАКТИРОВАТЬ: изменение ввода на

df1 = pd.DataFrame(np.random.randn(501, 16))
df2 = pd.DataFrame(np.random.randn(5001, 16))

меняет легенду на синий enter image description here

1 Ответ

1 голос
/ 18 апреля 2020

Это немного нелепо для решения, но вы создаете список цветов той же длины, что и один из ваших фреймов данных, а затем присваиваете эти цвета каждому графику.

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
df1 = pd.DataFrame(np.random.randn(10, 6))
df2 = pd.DataFrame(np.random.randn(10, 10))

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

# to account for different numbers of columns between dfs
if len(df2) > len(df1):
    colors = plt.cm.jet(np.linspace(0,1,len(df2)))
else:
    colors = plt.cm.jet(np.linspace(0,1,len(df1)))

df1.plot(ax=axes[0], style='-', color = colors, legend=True)
axes[0].set_xlabel('x')
axes[0].set_ylabel('y')
axes[0].set_title('ttl')

df2.plot(ax=axes[0], style='--', color = colors, legend=True)
axes[0].set_xlabel('x')
axes[0].set_ylabel('y')
axes[0].set_title('ttl')

plt.show()

enter image description here

...