Отображение правильной легенды при создании точечного графика с палитрой - PullRequest
0 голосов
/ 01 февраля 2019

Глупый способ построения точечной диаграммы

Предположим, у меня есть данные с 3 классами, следующий код может дать мне идеальный график с правильной легендой, в которой я строю класс данных по классам.

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_blobs
import numpy as np

X, y = make_blobs()

X0 = X[y==0]
X1 = X[y==1]
X2 = X[y==2]

ax = plt.subplot(1,1,1)
ax.scatter(X0[:,0],X0[:,1], lw=0, s=40)
ax.scatter(X1[:,0],X1[:,1], lw=0, s=40)
ax.scatter(X2[:,0],X2[:,1], lw=0, s=40)
ax.legend(['0','1','2'])

enter image description here

Лучший способ построения точечной диаграммы

Однако, если у меня есть набор данных с 3000 классами, вышеуказанный методбольше не работает(Вы не ожидаете, что я напишу 3000 строк, соответствующих каждому классу, верно?) Поэтому я придумываю следующий код для черчения.

num_classes = len(set(y))
palette = np.array(sns.color_palette("hls", num_classes))

ax = plt.subplot(1,1,1)
ax.scatter(X[:,0], X[:,1], lw=0, s=40, c=palette[y.astype(np.int)])
ax.legend(['0','1','2'])

enter image description here

Этот код идеален, мы можем построить все классы только одной строкой.Однако в этот раз легенда отображается неправильно.

Вопрос

Как сохранить правильную легенду при построении графиков с помощью следующего?

ax.scatter(X[:,0], X[:,1], lw=0, s=40, c=palette[y.astype(np.int)])

Ответы [ 2 ]

0 голосов
/ 01 февраля 2019

Почему бы просто не сделать следующее?

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_blobs
import numpy as np

X, y = make_blobs()
ngroups = 3

ax = plt.subplot(1, 1, 1)
for i in range(ngroups):
    ax.scatter(X[y==i][:,0], X[y==i][:,1], lw=0, s=40, label=i)
ax.legend()
0 голосов
/ 01 февраля 2019

plt.legend() работает лучше всего, когда на сюжете несколько «художников».Это ваш первый пример, поэтому вызов plt.legend(labels) работает без усилий.

Если вас беспокоит написание большого количества строк кода, вы можете воспользоваться for циклами.

Как видно из этого примера, используя 5 классов:

import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
import numpy as np

X, y = make_blobs(centers=5)
ax = plt.subplot(1,1,1)

for c in np.unique(y):
    ax.scatter(X[y==c,0],X[y==c,1],label=c)

ax.legend()

enter image description here

np.unique() возвращает отсортированный массив уникальных элементов y, просматривая их и нанося каждый класс с собственным художником, plt.legend() может легко предоставить легенду.

Редактировать:

Вы также можете назначать метки для графиков по мере их создания, чтовероятно, безопаснее.

plt.scatter(..., label=c), за которыми следует plt.legend()

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...