Numba Jit возвращает ошибку при столкновении с цифровым оцифровыванием, несмотря на то, что якобы поддерживается - PullRequest
0 голосов
/ 22 декабря 2018

Я пытаюсь оптимизировать эту функцию, она использует jitclass, который в настоящее время компилируется без проблем.После запуска этой функции выдается сообщение об ошибке, говорящее о том, что чтение строки x = np.digitize(x1, bins) не удалось скомпилировать из-за того, что np.digitize не поддерживается numba jit.

def eps_q_learning(env, episodes=500, eps=.5, lr=.8, y=.95, decay_factor=.999):
    q_table = np.zeros((26, 2, 2))
    bins = np.array([-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    for i in range(episodes):
        x1, news, done = env.reset()[0:3]
        x = np.digitize(x1, bins)
        eps *= decay_factor
        if np.random.random() < eps or np.sum(q_table[x, int(news), :]) == 0:
            a = np.random.randint(0, 2)
        else:
            a = np.argmax(q_table[x, int(news), :])
        _, news, done, reward = env.step(a)
        q_table[x, int(news), a] += reward + lr * (y * np.max(q_table[x, int(news), :]) - q_table[x, int(news), a])
    return q_table

Сообщение об ошибке:

    Traceback (most recent call last):
  File "<input>", line 3, in <module>
  File "C:\Users\Coen D. Needell\AppData\Local\Programs\Python\Python36\lib\site-packages\numba\dispatcher.py", line 348, in _compile_for_args
    error_rewrite(e, 'typing')
  File "C:\Users\Coen D. Needell\AppData\Local\Programs\Python\Python36\lib\site-packages\numba\dispatcher.py", line 315, in error_rewrite
    reraise(type(e), e, None)
  File "C:\Users\Coen D. Needell\AppData\Local\Programs\Python\Python36\lib\site-packages\numba\six.py", line 658, in reraise
    raise value.with_traceback(tb)
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function digitize>) with argument(s) of type(s): (float64, array(int64, 1d, C))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<built-in function digitize>)
[2] During: typing of call at D:\CODE\woke-gpu\StocksGame.py (97)
File "StocksGame.py", line 97:
def eps_q_learning(env, episodes=500, eps=.5, lr=.8, y=.95, decay_factor=.999):
    <source elided>
        x1, news, done = env.reset()[0:3]
        x = np.digitize(x1, bins)
        ^
This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.
To see Python/NumPy features supported by the latest release of Numba visit:
http://numba.pydata.org/numba-doc/dev/reference/pysupported.html
and
http://numba.pydata.org/numba-doc/dev/reference/numpysupported.html
For more information about typing errors and how to debug them visit:
http://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile
If you think your code should work with Numba, please report the error message
and traceback, along with a minimal reproducer at:
https://github.com/numba/numba/issues/new

Есть идеи?Это windows 10, использующая numba версию 0.41.0, python версию 3.6.7, numpy версию 1.15.4.

...