Numba, странная модификация типа - PullRequest
1 голос
/ 28 января 2020

Я пытаюсь еще больше ускорить код, написанный на python, скомпилированный с использованием Numba. Глядя на сборку, сгенерированную numba, я заметил, что генерируются операции двойной точности, что, на мой взгляд, было странно, поскольку все входные и выходные данные должны быть float32.

Я объявляю типы переменных / массивов как float32 за пределами джитеда l oop и передайте их в функцию. Странно, но после запуска моих тестов переменная scalarout преобразуется в python float, что на самом деле является 64-битным значением.

Мой код:

from scipy import ndimage, misc
import matplotlib.pyplot as plt
import numpy.fft
from timeit import default_timer as timer
import numba
# numba.config.DUMP_ASSEMBLY = 1
from numba import float32
from numba import jit, njit, prange
from numba import cuda
import numpy as np
import scipy as sp

# import llvmlite.binding as llvm
# llvm.set_option('', '--debug-only=loop-vectorize')

@njit(fastmath=True, parallel=False)
def mydot(a, b, xlen, ylen, scalarout):
    scalarout = (np.float32)(0.0)
    for y in prange(ylen):
        for x in prange(xlen):
            scalarout += a[y, x] * b[y, x]
    return scalarout

# ======================================== TESTS ========================================

print()
xlen = 100000
ylen = 16
a = np.random.rand(ylen, xlen).astype(np.float32)
b = np.random.rand(ylen, xlen).astype(np.float32)
print("a type = ", type(a[1,1]))
scalarout = (np.float32)(0.0)
print("scalarout type, before execution = ", type(scalarout))
iters=1000

time = 100.0
for n in range(iters):
    start = timer()
    scalarout = mydot(a, b, xlen, ylen, scalarout)
    end = timer()
    if(end-start < time):
        time = end-start
print("Numba njit function time, in us = %16.10f" % ((end-start)*10**6))
print("function output = %f" % scalarout)
print("scalarout type, after execution = ", type(scalarout))

1 Ответ

1 голос
/ 28 января 2020

Это скорее расширенный комментарий, чем ответ. Если вы измените scalarout на массив float32 длины 1 и измените его, вы получите float32.

@njit(fastmath=True, parallel=False)
def mydot(a, b, xlen, ylen):
    scalarout = np.array([0.0], dtype=np.float32)
    for y in prange(ylen):
        for x in prange(xlen):
            scalarout[0] += a[y, x] * b[y, x]
    return scalarout

Если вы измените return scalarout на return scalarout[0], то выход снова будет python float.

В исходном коде для mydot результат будет python float, даже если вы напишите return np.float32(scalarout).

...