Избегайте гоночных условий в Нумбе - PullRequest
1 голос
/ 22 апреля 2020

Здесь представлена ​​игрушечная функция njit, которая принимает матрицу расстояний, проходит по каждой строке матрицы и записывает минимальное значение в каждом столбце, а также из какой строки получено минимальное значение. Однако, IIU C, с использованием prange, это могло бы вызвать состояние гонки (особенно для больших входных массивов):

from numba import njit, prange
import numpy as np

@njit
def some_transformation_func(D, row_i):
   """
   This function applies some transformation to the ith row (`row_i`) in the `D` matrix in place. 
   However, the transformation time is random (but all less than a second), which means 
   that the rows can take
   """

    # Apply some inplace transformation on the ith row of D


@njit(parallel=True)
def some_func(D):
    P = np.empty((D.shape[1]))
    I = np.empty((D.shape[1]), np.int64)
    P[:] = np.inf
    I[:] = -1

    for row in prange(D.shape[0]):
        some_transformation_func(D, row)
        for col in range(D.shape[1]):
            if P[col] > D[row, col]:
                P[col] = D[row, col]
                I[col] = row

    return P, I

if __name__ == "__main__":
    D = np.array([[4,1,6,9,9], 
                  [1,3,8,2,7], 
                  [2,8,0,0,1],
                  [3,7,4,6,5]
                 ])
    P, I = some_func(D)
    print(P)
    print(I)

    # [1. 1. 0. 0. 1.]
    # [1 0 2 2 2]

Как я могу подтвердить, существует ли состояние гонки ( особенно если D очень большой с большим количеством строк и столбцов)? И, что более важно, если есть состояние гонки, как я могу избежать этого?

1 Ответ

0 голосов
/ 26 апреля 2020

В этих ситуациях вместо установки prange в качестве размера массива лучше всего вручную разбить данные на количество n_threads, затем распределить обработку соответствующим образом и, наконец, выполнить сокращение. Итак, что-то вроде этого:

from numba import njit, prange, config
import numpy as np

@njit
def wrapper_func(thread_idx, start_indices, stop_indices, D, P, I):
    for row in range(start_indices[thread_idx], stop_indices[thread_idx]):
        some_transformation_func(D, row)
        for col in range(D.shape[1]):
            if P[thread_idx, col] > D[row, col]:
                P[thread_idx, col] = D[row, col]
                I[thread_idx, col] = row


@njit
def some_transformation_func(D, row_i):
   """
   This function applies some transformation to the ith row (`row_i`) in the `D` matrix in place. 
   However, the transformation time is random (but all less than a second), which means 
   that the rows can take
   """

    # Apply some inplace transformation on the ith row of D


@njit(parallel=True)
def some_func(D):
    n_threads = config.NUMBA_NUM_THREADS  # Let's assume that there are 2 threads
    P = np.empty((n_threads, D.shape[1]))
    I = np.empty((n_threads, D.shape[1]), np.int64)
    P[:, :] = np.inf
    I[:, :] = -1

    start_indices = np.array([0, 2], np.int64)
    stop_indices = np.array([2, 4], np.int64)  # Note that these are exclusive

    for thread_idx in prange(n_threads):
        wrapper_func(thread_idx, start_indices, stop_indices, D, P, I)

    # Perform reduction from all threads and store results in P[0]
    for thread_idx in range(1, n_threads):
        for i in prange(l):
            if P[0, i] > P[thread_idx, i]:
                P[0, i] = P[thread_idx, i]
                I[0, i] = I[thread_idx, i]

    return P[0], I[0]

if __name__ == "__main__":
    D = np.array([[4,1,6,9,9], 
                  [1,3,8,2,7], 
                  [2,8,0,0,1],
                  [3,7,4,6,5]
                 ])
    P, I = some_func(D)
    print(P)
    print(I)

    # [1. 1. 0. 0. 1.]
    # [1 0 2 2 2]

Обратите внимание, что это будет стоить вам больше памяти (ровно n_threads больше памяти), но вы выиграете от распараллеливания. Кроме того, код становится чище и намного проще в обслуживании. Что нужно сделать, так это придумать лучший способ разбить данные на части и определить индексы start_row и stop_row (эксклюзивные).

...