Я использую Numba для ускорения обработки некоторого кода. Работа легко распараллеливается, и я даю попытку numba.prange. Я ожидал бы почти линейного масштабирования с количеством потоков (по крайней мере, до тех пор, пока не будет достигнут пропускной способности памяти), но я почти не увеличиваю масштаб.
Я просто пишу массив Numpy срезами, с каждым потоком, работающим над своим срезом:
@numba.njit
def do_work(i_row_begin, i_row_end, out):
a = np.empty(5)
for i_row in range(i_row_begin, i_row_end):
a[0] = i_row
for index_s in range(out.shape[1]):
a[1] = index_s
for index_t in range(out.shape[2]):
a[2] = index_t
a[3] = index_t / (1.2 + i_row)
a[4] = index_t / (1.8 + i_row)
out[i_row, index_s, index_t] = np.sum(a / (1 + np.sum(a)))
@numba.njit(parallel=True)
def do_work_parallel(num_threads, num_rows):
out = np.empty((num_rows, 3, 300))
# calculate threads
num_rows_per_thread = int(math.ceil(num_rows / num_threads))
for index_thread in numba.prange(num_threads):
# Loop over loan parts
i_row_begin = index_thread * num_rows_per_thread
i_row_end = min(num_rows, (index_thread + 1) * num_rows_per_thread)
do_work(i_row_begin, i_row_end, out)
return out
И это основной сценарий:
n_rows = 10000
def run(num_threads):
return do_work_parallel(num_threads, n_rows)
# Execute function once to compile numba functions
_ = run(2)
for num_threads in [1, 2, 3, 4]:
print("Num threads", num_threads)
now = time.time()
_ = run(num_threads)
diff = time.time() - now
print("Time elapsed: {:.3e}".format(diff))
print("Speed: {:.3e} rows/s/core".format(n_rows/(diff * num_threads)))
print("")
print("DONE")
Вывод, который я получаю, следующий:
Num threads: 1
Time elapsed: 1.078e+00
Speed: 9.275e+03 rows/s/core
Num threads: 2
Time elapsed: 1.484e+00
Speed: 3.368e+03 rows/s/core
Num threads: 3
Time elapsed: 1.442e+00
Speed: 2.311e+03 rows/s/core
Num threads: 4
Time elapsed: 1.469e+00
Speed: 1.702e+03 rows/s/core
Так ясно, что скорости вообще нет. На самом деле распараллеливание ухудшает производительность. Кто-нибудь может объяснить, почему это происходит?
Редактировать: Похоже, связано с этой ошибкой: https://github.com/numba/numba/issues/2699