Pylab: сопоставить метки с цветами - PullRequest
5 голосов
/ 16 марта 2012

Я только начинаю со стека scipy. Я использую набор данных радужной оболочки в версии CSV. Я могу просто загрузить его, используя:

iris=numpy.recfromcsv("iris.csv")

и нанесите на карту:

pylab.scatter(iris.field(0), iris.field(1))
pylab.show()

Теперь я хотел бы также построить классы, которые хранятся в iris.field(4):

chararray(['setosa', ...], dtype='|S10')

Какой элегантный способ отобразить эти строки в цвета для черчения? scatter(iris.field(0), iris.field(1), c=iris.field(4)) не работает (из документов он ожидает значения с плавающей запятой или цветовую карту). Я не нашел элегантного способа автоматического создания цветовой карты.

cols = {"versicolor": "blue", "virginica": "green", "setosa": "red"}
scatter(iris.field(0), iris.field(1), c=map(lambda x:cols[x], iris.field(4)))

делает примерно то, что я хочу, но мне не очень нравится ручная цветовая спецификация.

Редактировать : немного более элегантный вариант последней строки:

scatter(iris.field(0), iris.field(1), c=map(cols.get, iris.field(4)))

Ответы [ 2 ]

5 голосов
/ 16 марта 2012

Является ли способ элегантным или нет, это несколько субъективно. Я лично нахожу ваши подходы лучше, чем «matplotlib». Из color модуля matplotlib:

Colormapping обычно состоит из двух этапов: массив данных является первым отображается на диапазон 0-1 с использованием экземпляра Normalize или подкласс; затем это число в диапазоне 0-1 сопоставляется с цветом с помощью экземпляр подкласса Colormap.

Что касается вашей проблемы, я понимаю, что вам нужен подкласс Normalize, который принимает строки и отображает их в 0-1.

Вот пример, который наследуется от Normalize для создания подкласса TextNorm, который используется для преобразования строки в значение от 0 до 1. Эта нормализация используется для получения соответствующего цвета.

import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import numpy as np
from numpy import ma

class TextNorm(Normalize):
    '''Map a list of text values to the float range 0-1'''

    def __init__(self, textvals, clip=False):
        self.clip = clip
        # if you want, clean text here, for duplicate, sorting, etc
        ltextvals = set(textvals)
        self.N = len(ltextvals)
        self.textmap = dict(
            [(text, float(i)/(self.N-1)) for i, text in enumerate(ltextvals)])
        self.vmin = 0
        self.vmax = 1

    def __call__(self, x, clip=None):
        #Normally this would have a lot more to do with masking
        ret = ma.asarray([self.textmap.get(xkey, -1) for xkey in x])
        return ret

    def inverse(self, value):
        return ValueError("TextNorm is not invertible")

iris = np.recfromcsv("iris.csv")
norm = TextNorm(iris.field(4))

plt.scatter(iris.field(0), iris.field(1), c=norm(iris.field(4)), cmap='RdYlGn')
plt.savefig('textvals.png')
plt.show()

Это производит:

enter image description here

Я выбрал цветовую карту «RdYlGn», чтобы было легко различать три типа точек. Я не включил функцию clip в состав __call__, хотя это возможно с некоторыми изменениями.

Традиционно вы можете проверить нормализацию метода scatter, используя ключевое слово norm, но scatter проверяет ключевое слово c, чтобы увидеть, хранит ли оно строки, и если это так, то предполагается, что вы передаете в цветах как их строковые значения, например «Красный», «Синий» и т. Д. Так что набрать plt.scatter(iris.field(0), iris.field(1), c=iris.field(4), cmap='RdYlGn', norm=norm) не удается. Вместо этого я просто использую TextNorm и «оперирую» на iris.field(4), чтобы получить массив значений в диапазоне от 0 до 1.

Обратите внимание, что для строки, не указанной в списке, возвращается значение -1 textvals. Вот где маскировка пригодится.

4 голосов
/ 17 марта 2012

Для чего бы это ни стоило, вы обычно делаете что-то более подобное в этом случае:

import numpy as np
import matplotlib.pyplot as plt

iris = np.recfromcsv('iris.csv')
names = set(iris['class'])

x,y = iris['sepal_length'],  iris['sepal_width']

for name in names:
    cond = iris['class'] == name
    plt.plot(x[cond], y[cond], linestyle='none', marker='o', label=name)

plt.legend(numpoints=1)
plt.show()

enter image description here

Нет ничего плохого в том, что предложил @Yann, но scatter лучше подходит для непрерывных данных.

Проще полагаться на цветовой цикл осей и просто вызывать график несколько раз (вместо коллекции вы также получаете отдельных художников, что хорошо для дискретных данных, таких как эта).

По умолчанию цветовой цикл для осей: синий, зеленый, красный, голубой, пурпурный, желтый, черный.

После 7 вызовов на plot, он будет переключаться обратно по этим цветам, поэтому, если у вас есть больше элементов, вам нужно установить его вручную (или просто указать цвет в каждом вызове plot с использованием интерполированной цветовой шкалы, аналогичной предложенной @Yann выше.

...