Как создать один морской сюжет, состоящий из диаграмм рассеяния kx2 с общей легендой, объединяющей все классы - PullRequest
3 голосов
/ 26 января 2020

У меня есть несколько диаграмм разброса, с разными классами в каждом. Я хочу соединить их все вместе в сетку графиков kx2 с легендой на стороне, которая содержит все существующие классы, например, удалить легенду с отдельных графиков.

Как мне это сделать?

Здесь это 4 графика для теста 2x2

from  matplotlib.lines import  Line2D
import pandas as pd
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt
from  matplotlib.lines import  Line2D

df1 = pd.DataFrame({
    "class":["a", "b", "e"],
    "time":[1,2,3],
    "score":[10, 20, 30]
})

df2 = pd.DataFrame({
    "class":["a", "c", "d"],
    "time":[0,5,10],
    "score":[5, 25, 30]
})

df3 = pd.DataFrame({
    "class":["a", "b", "c", "d", "e"],
    "time":[0,5,10,30,50],
    "score":[5, 25, 30, 40, 100]
})

df4 = pd.DataFrame({
    "class":["a", "e"],
    "time":[1,2],
    "score":[10,25]
})

def get_palette():
  pal =  {
      'a': "#4C72B0", 
      'b': "#55A868", 
      'c': "#C44E52", 
      'd': "#8172B2", 
      'e': "#CCB974", 
  }
  return pal

def get_markers():
  mark = {
      'a': Line2D.filled_markers[0], 
      'b': Line2D.filled_markers[5], 
      'c': Line2D.filled_markers[6], 
      'd': Line2D.filled_markers[7],  
      'e': Line2D.filled_markers[8], 
  }
  return mark

def get_scatterplot(source, ds_name):
  scatter = sns.scatterplot(palette=get_palette(), markers=get_markers(), 
                            edgecolor='black', alpha=0.6, x="score", y="time",
                            hue="class", style="class", s=150, 
                            data=source).set_title(ds_name)
  return scatter

scatter_df1 = get_scatterplot(df1, "df1")
plt.show()

scatter_df2 = get_scatterplot(df2, "df2")
plt.show()

scatter_df3 = get_scatterplot(df3, "df3")
plt.show()

scatter_df4 = get_scatterplot(df4, "df4")
plt.show()

Это то, что я пытаюсь сделать, основываясь на некоторых других ответах на стек

fig, axs = plt.subplots(ncols=2, nrows=2)
sns.scatterplot(palette=get_palette(), markers=get_markers(), edgecolor='black', alpha=0.6, x="score", y="time", hue="class", style="class", s=150, data=df1, ax=axs[0]).set_title("ds1")
sns.scatterplot(palette=get_palette(), markers=get_markers(), edgecolor='black', alpha=0.6, x="score", y="time", hue="class", style="class", s=150, data=df2, ax=axs[1]).set_title("ds2")
sns.scatterplot(palette=get_palette(), markers=get_markers(), edgecolor='black', alpha=0.6, x="score", y="time", hue="class", style="class", s=150, data=df3, ax=axs[2]).set_title("ds3")
sns.scatterplot(palette=get_palette(), markers=get_markers(), edgecolor='black', alpha=0.6, x="score", y="time", hue="class", style="class", s=150, data=df4, ax=axs[3]).set_title("ds4")

Но это ошибки, не знаю почему ...

AttributeError: 'numpy.ndarray' object has no attribute 'scatter'

Ответы [ 2 ]

3 голосов
/ 26 января 2020

Вы можете использовать matplotlib.pyplot.figlegend, чтобы создать одну легенду для фигуры. Без передачи аргументов это создаст легенду из " существующих художников на каждой оси. ". Если вы хотите настроить это, вы можете предоставить дескрипторы легенды и метки напрямую.

Поскольку вы явно указав цвета для каждого «класса», довольно легко составить пользовательскую легенду:

pal = get_palette()
handles = [Line2D([0], [0], color=c) for l, c in pal.items()]
labels = [l for l in pal]
plt.figlegend(handles=handles, labels=labels, loc='best')
plt.show()

Нужно добиться цели. С plt.subplots(nrows=2, ncols=2) и кодом вопроса это даст вам легенду, которая выглядит следующим образом

enter image description here

Обратите внимание, что это будет работать для любого числа классов и любое количество вспомогательных участков в любой конфигурации при условии, что все классы и их соответствующие цвета определены в pal, в противном случае необходимо будет использовать более продвинутый метод.

2 голосов
/ 26 января 2020

Чтобы устранить вашу последнюю ошибку, вам нужно передать ax в матричном стиле с индексированием строк / столбцов, так как вы указываете макет подплота с nrow и ncol:

...

fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(12,8))

sns.scatterplot(..., ax=axs[0,0]).set_title("ds1")
sns.scatterplot(..., ax=axs[0,1]).set_title("ds2")
sns.scatterplot(..., ax=axs[1,0]).set_title("ds3")
sns.scatterplot(..., ax=axs[1,1]).set_title("ds4")

plt.tight_layout()
plt.show()

OP Plot


Чтобы разрешить желаемый результат для общей легенды и даже общих осей, рассмотрите возможность объединения всех фреймов данных в один и запускайте график с помощью seaborn.FacetGrid. Одно немедленное изменение - функция маркеров, которая требует список вместо dict. ...

def get_markers_list():
  mark = [
      Line2D.filled_markers[0], 
      Line2D.filled_markers[5], 
      Line2D.filled_markers[6], 
      Line2D.filled_markers[7],  
      Line2D.filled_markers[8], 
  ]
  return mark

# COMPILE ALL DFs INTO ONE
master_df = pd.concat([df1.assign(grp="ds1"),
                       df2.assign(grp="ds2"),
                       df3.assign(grp="ds3"),
                       df4.assign(grp="ds4")])

# RUN FACET GRID
g = sns.FacetGrid(master_df, col="grp", hue="class", col_wrap=2, 
                  aspect=1.5, palette=get_palette(),
                  hue_order=list('abcde'),
                  hue_kws=dict(marker=get_markers_list()))

g = (g.map(sns.scatterplot, "score", "time", 
           edgecolor='black', alpha=0.6, s=150)
      .add_legend())

plt.show()

Proposed Plot

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...