Защита от «индекса 0 выходит за границы для оси 0 с размером 0» в Python - PullRequest
0 голосов
/ 28 октября 2019

У меня есть код, в котором я получаю конкретное распределение точек на графике функции tan(), ограниченное снизу и сверху прямыми линиями:

import matplotlib.pyplot as plt
import numpy as np
import sys
import itertools
import multiprocessing
import tqdm

ic = range(1,10)
jc = range(1,10)

paramlist = list(itertools.product(ic,jc))

def func(params):
        ic = params[0]
        jc = params[1]

        fig = plt.figure(1, figsize=(10,6))

        x_all   = np.linspace(0, 10*np.pi, 10000, endpoint=False)
        x_above = x_all[ (-0.01)*ic*x_all < np.tan(x_all) ]
        x       = x_above[ np.tan(x_above) < 0.01*jc*x_above ]
        y       = np.tan(x)
        y2      = 0.01*jc*x
        y3      = (-0.01)*ic*x

        y_up   = np.diff(y) > 0
        y_diff = np.where( y_up, np.diff(y), 0 )
        x_diff = np.where( y_up, np.diff(x), 0 )
        diffs  = np.sqrt( x_diff**2 + y_diff**2 )
        length = diffs.sum()

        numbers = [2,4,6,8,10,12,14,16,18,20]

        p2 = []
        for d in range(len(numbers)):
            cumlenth = np.cumsum(diffs)
            s = np.abs(np.diff(np.sign(cumlenth-numbers[d]))).astype(bool)
            c = np.argwhere(s)[0][0]
            p = x[c], y[c]
            p2.append(p)

        p3 = sorted(p2, key=lambda x: x[0])
        x_max = p3[len(p3)-1][0]
        p4 = sorted(p2, key=lambda x: x[1])
        y_min = p4[0][1]
        y_max = p4[len(p3)-1][1]


        for b in range(len(p2)):
            plt.scatter( p2[b][0], p2[b][1], color="crimson", s=8)

        plt.plot(x, np.tan(x))
        plt.plot(x, y2)
        plt.plot(x, y3)

        ax = plt.gca()
        ax.set_xlim([0, x_max+0.5])
        ax.set_ylim([y_min-0.5, y_max+0.5]) 

        plt.savefig('C:\\Users\\tkp\\Desktop\\wykresy_4\\i='+str(ic)+'_j='+str(jc)+'.png', bbox_inches='tight')
        plt.show()

if __name__ == '__main__':

    p = multiprocessing.Pool(4)
    for params in tqdm.tqdm(p.imap_unordered(func, paramlist), total=len(paramlist)):
        #pass
        sys.stdout.write('\r'+ str(params))
        sys.stdout.flush()

    p.close()
    p.join()

Где, например, яграфик получения:

enter image description here

Проблема в том, что если я установлю диапазон в x_all = np.linspace(0, 10*np.pi, 10000, endpoint=False) слишком мал, я получу ошибку index 0 is out of bounds for axis 0 with size 0. Как я могу защитить себя от этого? Или, может быть, в этом случае я могу установить диапазон переменных в функции "linspace"?

1 Ответ

2 голосов
/ 28 октября 2019

Где происходит эта ошибка? Это фундаментальная информация - для нас, но особенно для вас!

@ edison говорит, что это выражение argwhere. Я попытаюсь воссоздать этот шаг, начав с предположения о том, как выглядит diffs:

In [8]: x = np.ones(5)*.1                                                       
In [9]: x                                                                       
Out[9]: array([0.1, 0.1, 0.1, 0.1, 0.1])
In [10]: s = np.cumsum(x)                                                       
In [11]: s                                                                      
Out[11]: array([0.1, 0.2, 0.3, 0.4, 0.5])
In [12]: s-1                                                                    
Out[12]: array([-0.9, -0.8, -0.7, -0.6, -0.5])
In [13]: np.sign(s-1)                                                           
Out[13]: array([-1., -1., -1., -1., -1.])
In [14]: np.diff(np.sign(s-1))                                                  
Out[14]: array([0., 0., 0., 0.])
In [15]: np.abs(np.diff(np.sign(s-1)))                                          
Out[15]: array([0., 0., 0., 0.])
In [16]: np.abs(np.diff(np.sign(s-1))).astype(bool)                             
Out[16]: array([False, False, False, False])

Независимо от деталей к этому моменту, можно предположить, что s является массивомтолько с False. where находит элементы True в этом массиве;их нет.

In [17]: np.where(_)                                                            
Out[17]: (array([], dtype=int64),)

argwhere - это транспонирование этого - один столбец для каждого измерения и одна строка для каждого найденного элемента.

In [18]: np.argwhere(_)                                                         
Out[18]: array([], shape=(0, 2), dtype=int64)
In [19]: _[0]                                                                   
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-19-aa79beb95eae> in <module>
----> 1 _[0]

IndexError: index 0 is out of bounds for axis 0 with size 0

Итак, ваша первая строказащита заключается в проверке формы возвращаемого массива:

c = np.argwhere(s)
if c.shape[0]>0:
    c = c[0,0]
    p = x[c], y[c]
else:
    # what do you want to do if non of `s` are true?

Оттуда вы можете работать в обратном направлении, следя за тем, чтобы diffs или numbers были правильными и всегда находили действительные c. Но независимо от того, при использовании where или argwhere, будьте осторожны, если предположите, что он обнаружил определенное количество предметов.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...