Непоследовательное построение графиков с радиальной гистограммой в matplotlib - PullRequest
1 голос
/ 21 октября 2019

Я пытаюсь создать круговую гистограмму в Python / matplotlib. У меня есть код, хорошо работающий на моем рабочем столе (использующий Spyder), но вывод программы совершенно разный на разных компьютерах даже в одной и той же версии Python (3.5) и Spyder (3.3.1). Круговая гистограмма разбивается на один сегмент, и я не понимаю, почему.

Как это выглядит на моем рабочем столе (должно выглядеть)

Как это выглядит на других компьютерах

Код считывает данные из электронной таблицы. Он в основном извлекает значения стран и процентов, а затем регионы используются для цветов.

Пример ввода: data.xlsx

Я попытался убедиться, что версии Pythonодинаковы и прошли через каждую строку кода, и я думаю, что это как-то связано с моим призывом создать тэтагрид (примерно 10 строк снизу), который ведет себя по-разному в зависимости от компьютера. Вот где диаграмма переходит из циклического в странное поведение сегмента.

import pandas as pd
import pylab as pl
import numpy as np
import matplotlib.pyplot as plt

#plot params
bottom = 0
max_height = 10
DPI = 400
figsize = DPI*8
MaxPercentage = 0.12
MinPercentage = -0.12
Padlength = 45 #This is the maximum country name length

for sheetcount in range(0,2):

    #read in data
    xl = pd.ExcelFile('data.xlsx')
    sheet_names = xl.sheet_names
    currentsheet = sheet_names[sheetcount]
    datafortitle = pd.read_excel('data.xlsx',sheetname = currentsheet)
    ChartTitle = datafortitle.iloc[0,1]
    rawdata = pd.read_excel('data.xlsx',sheetname = currentsheet,usecols = [0,1,2,3], skiprows = 3)
    rawdata.columns = ['regions','countries','count','value']

    #populate regions (remove gaps)
    for i in range(0,len(rawdata.index)):
        if pd.isnull(rawdata.get_value(i, 'regions')): 
            rawdata.at[i, 'regions'] = rawdata.get_value(i-1,'regions') 


    #remove rows where count isnt 3
    rawdata = rawdata[~pd.isnull(rawdata).any(axis=1)]

    # Get names of indexes for which columns dont have enough data
    indexNames = rawdata[ rawdata['count'] <= 2].index

    # Delete these row indexes
    rawdata.drop(indexNames , inplace=True)
    rawdata = rawdata.reset_index(drop = True)

    fulldata = rawdata

    #cut into respective groups
    Regions = fulldata.iloc[:,0].astype(np.str)
    Countries_Lower = fulldata.iloc[:,1].astype(np.str)
    Countries = Countries_Lower.str.upper()
    ActualSalary = fulldata.iloc[:,3].astype(np.float)
    ActualSalaryStr = fulldata.iloc[:,3].astype(np.str)

    N = len(ActualSalary)

    #Set colours for regions
    Colours = ["" for x in range(N)]
    i = -1
    for x in Regions:
        i += 1
        x = str(x)
        if x == 'Africa':
            Colours[i] = '#490E3D'
        elif x == 'APAC':
            Colours[i] = '#E97F02'
        elif x == 'Europe':
            Colours[i] = '#F8CA00'        
        elif x == 'LATAM':
            Colours[i] = '#88C100'
        elif x == 'Middle East':
            Colours[i] = '#30C4C9'
        elif x == 'North America':
            Colours[i] = '#E40D2C'
        else:
            Colours[i] = 'black'        


    TrimmedSalary = ActualSalary.copy()
    TrimLocation = np.zeros(len(ActualSalary))

    #Trim values to be plotted to be less than or equal to MaxPercentage
    for i in range(0,len(ActualSalary)):
        if TrimmedSalary[i] > MaxPercentage:
            TrimmedSalary[i] = MaxPercentage
            TrimLocation [i] = 1
        if TrimmedSalary[i] < MinPercentage:
            TrimmedSalary[i] = MinPercentage
            TrimLocation [i] = 2


    theta = np.linspace(0.0, 2 * np.pi, N, endpoint=False) - np.pi/N
    r = max_height*TrimmedSalary
    width = (2*np.pi) / N

    with plt.rc_context({'font.cursive':'Textile','axes.edgecolor':'#8a8a8a', 'xtick.color':'#8a8a8a', 'figure.facecolor':'white'}):# 'font.family':'cursive','font.sans-serif':'Geneva'}):

        f = plt.figure(figsize=(figsize/DPI,figsize/DPI))                 
        ax = f.add_subplot(111, polar = True)

        #remove y axis clutter, setup x axis.
        ax.xaxis.grid(color = '#8a8a8a', linestyle = '-')
        ax.yaxis.grid(False)
        ax.set_yticklabels([])
        ax.set_axisbelow(True)
        ax.FontName = 'Sans'

        #add country labels
        ax.set_xticks(np.linspace(0,2*np.pi,N+1))

        bars = ax.bar(theta, r, width=width, bottom=bottom, edgecolor = "none", linewidth = 0)

        # Use custom colors and opacity
        i = 0
        for r, bar in zip(r, bars):
            bar.set_facecolor(Colours[i])
            bar.set_alpha(1.0)
            i += 1

        Colours.append('black')

        #Fix plot size
        ax.set_ylim([MinPercentage*max_height,MaxPercentage*max_height])

        ############################################
        #Plot labels
        Valuelabels = ["%.2f" % x for x in (ActualSalary*100)] 
        Valuelabels = [s + " %" for s in Valuelabels]
        maxlen = len(max(Valuelabels, key=len))

        angles = np.linspace(0,2*np.pi - 2*np.pi/N,N)
        angles[np.cos(angles) < 0] = angles[np.cos(angles) < 0] + np.pi
        angles = np.rad2deg(angles)
        angleplace = np.linspace(0,2*np.pi,N+1)
        angleplace = np.rad2deg(angleplace)

        for i in range(0,N):
            if 90 < angleplace[i] <= 270:
                Valuelabels[i] = Valuelabels[i].rjust(maxlen+1)      
            else:
                Valuelabels[i] = Valuelabels[i].ljust(maxlen+1)


        CombinedLabels = ["" for x in range(N)]

        for i in range(0,N):
            if 0 <= angleplace[i] <= 90:
                CombinedLabels[i] = Valuelabels[i] + Countries[i]
            elif 90 < angleplace[i] <= 270:
                CombinedLabels[i] = Countries[i] + Valuelabels[i]          
            else:
                CombinedLabels[i] = Valuelabels[i] + Countries[i]              


        #Pad to the same length
        maxlen = Padlength        
        for i in range(0,N):
            if 90 < angleplace[i] <= 270:
                CombinedLabels[i] = CombinedLabels[i].rjust(maxlen)      
            else:
                CombinedLabels[i] = CombinedLabels[i].ljust(maxlen)    

        ax.set_xticklabels(CombinedLabels, fontsize=10,)
        plt.gcf().canvas.draw()
        labels = []

        radius = 4    
        i = 0
        for label, angle in zip(ax.get_xticklabels(), angles):   
            x = i/N*(2*np.pi)
            #y = 3.1
            y = 2.77

            lab = ax.text(x,y, label.get_text(), color = Colours[i],
                          ha='center', va='center',
                          weight = 'bold', fontsize = 10,
                          fontproperties=prop)
            lab.set_rotation(angle)
            labels.append(lab)
            i += 1


        #Plot gridlines - SOMETHING GOES WRONG HERE
        thetagrid = np.linspace(0.0, 2*np.pi, N+1) - np.pi/N
        (lines,labelsempty)=plt.thetagrids(np.degrees(thetagrid))

        ax.set_xticklabels([])
        ax.set_axisbelow(False)

        #Add central circle
        circle = pl.Circle((0, 0), 0.4, transform=ax.transData._b, fc="white",ec = '#8a8a8a', zorder = 10)
        ax.add_artist(circle)        

        plt.savefig(ChartTitle, dpi = DPI, bbox_inches='tight')
        #plt.show()

Очень жаль весь код, но не знал, как его сократить.

Кто-нибудь знаетчего-то, чего мне не хватает? Разница в версиях возможно? Обновление для thetagrid? Или я просто использую все это совершенно неправильно.

Я был бы очень признателен за любую вашу помощь, поскольку я пытался исправить это уже пару недель и просто зашел в тупик. .

Большое спасибо.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...