Мультиклассовый график geom_density не работает - PullRequest
0 голосов
/ 14 марта 2019

У меня проблемы с использованием plotnine: я не могу создать графику с 3-мя классами (разделенными цветом).

import pandas as pd
import numpy as np

from plotnine import *

path = '/home/punkproger/workspace/MyWorkPython/TestWork/galaxy_identificator/data/train.csv'

df = pd.read_csv(path)

my_plot = ggplot(data=df[:30000], mapping=aes(x='ra', fill='class', color='class')) + geom_density( alpha=0.7)
print(my_plot)

В каждых 10 тыс. Сэмплов есть новый «класс» (0-2).

Результат будет:

this graph

Но если я изменю количество выборок на 10k (есть только 1 класс):

import pandas as pd
import numpy as np

from plotnine import *

path = '/home/punkproger/workspace/MyWorkPython/TestWork/galaxy_identificator/data/train.csv'

df = pd.read_csv(path)

my_plot = ggplot(data=df[:10000], mapping=aes(x='ra', fill='class', color='class')) + geom_density( alpha=0.7)
print(my_plot)

Результат:

this graph

Теперь у этого есть немного класса и цвета.Я хочу сделать 3 графика в одной плоскости, например:

this one

Я новичок на плотине и не вижу, что не так.Потратил много времени на попытки погуглить и решить эту проблему.

Здесь вы можете скачать данные: https://drive.google.com/file/d/1IMK1YtXG8Zl1lY8JJ12RtzDpHn65vQKi/view

1 Ответ

0 голосов
/ 14 марта 2019

Извините, я не могу загрузить ваши данные, но вот решение с имитацией данных.

import numpy as np
import pandas as pd
from plotnine import *

np.random.seed(0)

df = pd.DataFrame({'x': np.hstack((
                        np.random.normal(size=1000), 
                        np.random.normal(10, 2, size=1000), 
                        np.random.normal(-10, 2, size=1000))), 
                   'c': [0]*1000 + [1]*1000 + [2]*1000})

(ggplot(df, aes('x', color='c', fill='c')) + geom_density(alpha=0.7))

Дает это:

Взятие первых 1000 строк (соответствует c == 0):

(ggplot(df[:1000], aes('x', color='c', fill='c')) + geom_density(alpha=0.7))

Теперь создается категориальная переменная:

df['cat'] = df['c'].astype('category')
(ggplot(df, aes('x', color='cat', fill='cat')) + geom_density(alpha=0.7))

...