Numpy жалобы на неоднозначный массив: ValueError: истинное значение - PullRequest
0 голосов
/ 27 мая 2020

У меня есть минимальный код в Python 3, который использует numpy и функцию apply_along_axis. Я не могу понять причину появления этой ошибки:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Предоставление прямой формулы внутри lambda работает. Как только я использую другую функцию, я получаю эту ошибку. Должен ли я вернуть что-то еще?


Минимальный код:

import numpy as np

def logn(x, b):
    return np.log(x)/np.log(b)
def h(x, b):
    if x == 0:
        return 0
    else:
        return -x*logn(x, b)

p = np.array([0.00000000e+00, 9.99997956e-01, 2.04440466e-06])
print(np.apply_along_axis(lambda _e: h(_e, 3), -1, p))

1 Ответ

2 голосов
/ 27 мая 2020

Посмотрите, что apply_along_axis передает вашей функции:

In [99]: def foo(x): 
    ...:     print(x) 
    ...:     return x 
    ...:                                                                                 
In [100]: np.apply_along_axis(foo, -1, p)                                                
[0.00000000e+00 9.99997956e-01 2.04440466e-06]
Out[100]: array([0.00000000e+00, 9.99997956e-01, 2.04440466e-06])

В случае 1d-массива, он передает весь массив сразу. Он не повторяется в этом измерении. В этом вся цель apply_along_axis - передать 1d массивы вашей функции.

Судя по другим SO apply_along_axis не очень полезен, часто дает проблемы. Это не быстрее, чем более явная итерация. Для 3d (или выше) это может упростить итерацию (по двум «другим» осям) (но опять же не быстрее).

Для 1d p это проще:

In [102]: [h(_e,3) for _e in p]                                                          
Out[102]: [0, 1.8605270777946112e-06, 2.4378506521338855e-05]

Неитеративный подход заключается в использовании булевой маски для выбора p, используемых в вычислении. Таким образом, вам не нужно использовать скалярное выражение if:

In [106]: mask = p!=0                                                                    
In [107]: mask                                                                           
Out[107]: array([False,  True,  True])
In [108]: p1 = p[mask]                                                                   
In [109]: res = np.zeros(p.shape)                                                        
In [110]: res[mask] = -p1*logn(p1,3)                                                     
In [111]: res                                                                            
Out[111]: array([0.00000000e+00, 1.86052708e-06, 2.43785065e-05])

ufunc например np.log принимает параметр where, который можно использовать для обхода неверных входных значений:

In [114]: -p * np.log(p, where=(p!=0), out=np.zeros(p.shape))/np.log(3)                  
Out[114]: array([-0.00000000e+00,  1.86052708e-06,  2.43785065e-05])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...