Как использовать два набора цветовых карт в Matplotlib? - PullRequest
0 голосов
/ 27 марта 2020

Я пытаюсь построить два набора данных на одном рисунке, используя независимые цветовые карты для каждого набора данных. Я создал две карты цветов cmap_blue и cmap_green, используя sns.choose_colorbrewer_palette('sequential') соответственно. Однако, когда я попытался построить наборы данных, используя следующий код, кажется, что второй sns.set_palette() переопределяет первый, что приводит к зеленым градиентам для обоих наборов данных.

sns.set_context('paper')
fig, ax = plt.subplots(figsize=[2.5, 2.5])

ax.set_xlim(0,600)
ax.set_ylim(0,15)

sns.set_palette(cmap_blue)
ax.plot(time_prot60, SFT_prot60)
ax.plot(time_prot70, SFT_prot70)
ax.plot(time_prot80, SFT_prot80)

sns.set_palette(cmap_green)
ax.plot(time_buffer60, SFT_buffer60)
ax.plot(time_buffer70, SFT_buffer70)
ax.plot(time_buffer80, SFT_buffer80)

plt.grid(True)
plt.savefig('/content/gdrive/My Drive/SVG/prot.svg', format='svg', bbox_inches = 'tight')

1 Ответ

0 голосов
/ 27 марта 2020

Проблема в том, что sns.set_palette устанавливает цветовой цикл по умолчанию для matplotlib. Но цветовой цикл также является свойством каждого ax. Следовательно, sns.set_palette влияет только на ax, созданный впоследствии. В вопросе о посте ax уже был создан ранее и уже имел свой собственный цветовой цикл.

Итак, чтобы получить желаемое поведение, палитра должна быть явно назначена ax , Функция ax.set_prop_cycle делает именно это. Вот код Обратите внимание, что я переименовал cmap_blue в palette_blue, чтобы провести различие между картой цветов (функция, которая выводит цвета) и палитрой (список цветов).

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from cycler import cycler

fig, ax = plt.subplots()

palette_blue = sns.color_palette("Blues")
palette_green = sns.color_palette("Greens")

ax.set_prop_cycle(cycler(color=palette_blue))
for i in range(1, 4):
    plt.plot(np.linspace(0, 10, 100), 10 * i + np.random.normal(0, 1, 100).cumsum())
ax.set_prop_cycle(cycler(color=palette_green))
for i in range(4, 7):
    plt.plot(np.linspace(0, 10, 100), 10 * i + np.random.normal(0, 1, 100).cumsum())

plt.show()

enter image description here

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...