Невозможно скомпилировать функцию с JIT с функцией, которая принимает параметры * args - PullRequest
1 голос
/ 27 октября 2019

Я пытаюсь скомпилировать функцию, которая принимает массивный массив и кортеж параметров формы * arg, используя numba.

import numba as nb
import numpy as np

@nb.njit(cache=True)
def myfunc(t, *p):
    val = 0
    for j in range(0, len(p), 2):
        val += p[j]*np.exp(-p[j+1]*t)
    return val

T = np.arange(12)
pars = (1.0, 2.0, 3.0, 4.0)
mfunc = myfunc(T, *pars)

Однако я получаю этот результат

In [1]: run numba_test.py                                                                                                                                                                  
---------------------------------------------------------------------------                                                                                                                
TypingError                               Traceback (most recent call last)                                                                                                                
~/Programs/my-python/numba_test.py in <module>                                                                                                                                             
     12                                                                                                                                                                                    
     13 T = np.arange(12)                                                                                                                                                                  
---> 14 mfunc = myfunc(T, 1.0, 2.0, 3.0, 4.0)                                                                                                                                              

...
...                                                                                                                                                                                   
TypingError: Failed in nopython mode pipeline (step: nopython frontend)                                                                                                                    
Invalid use of Function(<built-in function iadd>) with argument(s) of type(s): (Literal[int](0), array(float64, 1d, C))                                                                    
Known signatures:                                                                                                                                                                          
 * (int64, int64) -> int64                                                                                                                                                                 
 * (int64, uint64) -> int64                                                                                                                                                                
 * (uint64, int64) -> int64                                                                                                                                                                
 * (uint64, uint64) -> uint64                                                                                                                                                              
 * (float32, float32) -> float32                                                                                                                                                           
 * (float64, float64) -> float64                                                                                                                                                           
 * (complex64, complex64) -> complex64                                                                                                                                                     
 * (complex128, complex128) -> complex128                                                                                                                                                  
 * parameterized                                                                                                                                                                           
In definition 0:                                                                                                                                                                           
    All templates rejected with literals.                                                                                                                                                  
...
...                                                                                                                                                                         
    All templates rejected without literals.                                                                                                                                               
This error is usually caused by passing an argument of a type that is unsupported by the named function.                                                                                   
[1] During: typing of intrinsic-call at /home/cshugert/Programs/my-python/numba_test.py (9)                                                                                                

File "numba_test.py", line 9:                                                                                                                                                              
def myfunc(t, *p):                                                                                                                                                                         
    <source elided>                                                                                                                                                                        
    for j in range(0, len(p), 2):                                                                                                                                                          
        val += p[j]*np.exp(-p[j+1]*t)                                                                                                                                                      
        ^                                                                                                                                                                                  

Numba поддерживает работу с кортежами, поэтому я решил, что в компиляторе jit может быть некоторая подпись. Тем не менее, я не уверен, что именно положить туда. Может ли быть так, что компиляторы numba не могут обрабатывать функции с * args в качестве параметров? Могу ли я что-нибудь сделать, чтобы моя функция работала?

1 Ответ

1 голос
/ 27 октября 2019

Давайте снова посмотрим на сообщение об ошибке

TypingError: Failed in nopython mode pipeline (step: nopython frontend)                                                                                                                    
Invalid use of Function(<built-in function iadd>) with argument(s)
 of type(s): (Literal[int](0), array(float64, 1d, C))                                                                    
Known signatures:                                                                                                                                                                          
 * (int64, int64) -> int64                                                                                                                                                                 
 * (int64, uint64) -> int64                                                                                                                                                                
 * (uint64, int64) -> int64                                                                                                                                                                
 * (uint64, uint64) -> uint64                                                                                                                                                              
 * (float32, float32) -> float32                                                                                                                                                           
 * (float64, float64) -> float64                                                                                                                                                           
 * (complex64, complex64) -> complex64                                                                                                                                                     
 * (complex128, complex128) -> complex128                                                                                                                                                  
 * parameterized  

Ошибка относится к <built-in function iadd>, что составляет +. Если вы посмотрите на ошибку, то это не из-за передачи *args, аиз-за следующего утверждения:

val += p[j]*np.exp(-p[j+1]*t)

В основном из всех упомянутых совместимых типов для + он не поддерживает добавление integer к array (см. сообщение об ошибке и известные сигнатуры для получения дополнительной информации. info).

Вы можете исправить это, установив val в виде массива нулей, используя np.zeros (см. документ здесь ).

import numba as nb
import numpy as np

@nb.njit
def myfunc(t, *p):
    val = np.zeros(12) #<------------ Set it as an array of zeros
    for j in range(0, len(p), 2):
        val += p[j]*np.exp(-p[j+1]*t)
    return val

T = np.arange(12)
pars = (1.0, 2.0, 3.0, 4.0)
mfunc_val = myfunc(T, *pars)

Youможете просмотреть код здесь в этой записной книжке Google Colab .

...