Я реализую простую функцию умножения матриц с помощью Numba и обнаружил, что она значительно медленнее, чем NumPy. В приведенном ниже примере Numba медленнее в 40 раз. Есть ли способ еще больше ускорить Нумбу? Заранее спасибо за ваш отзыв.
import time
import numpy as np
import numba
from numba import njit, prange
@numba.jit('void(float64[:,:],float64[:,:],float64[:,:])', fastmath=True, parallel=True)
def matmul(matrix1,matrix2,rmatrix):
a = matrix1.shape[0]
b = matrix2.shape[1]
c = matrix2.shape[0]
for i in prange(a):
for j in prange(b):
for k in prange(c):
rmatrix[i,j] += matrix1[i,k] * matrix2[k,j]
M = np.random.normal(0,10,(10,10))**2
N = np.random.normal(0,10,(10,10))**2
A = np.random.normal(0,10,(10,10))**2
matmul(M,N,A) #to make sure compiled
n = 3000
M = np.random.normal(0,10,(n,1000))**2
N = np.random.normal(0,10,(1000,n))**2
A = np.zeros((3000,3000))
t = time.time()
matmul(M,N,A)
print("Numba:", time.time()-t)
t = time.time()
np.dot(np.log(M),np.log(N))
print("NumPy:", time.time()-t)