Быстрая альтернатива условной установке элементов массива - PullRequest
0 голосов
/ 01 августа 2020

У меня есть два заданных трехмерных массива x_dist и y_dist, каждый из которых имеет форму (36,50,50). Элементы в x_dist и y_dist относятся к типу np.float32 и могут быть положительными, отрицательными или нулевыми. Мне нужно создать новый массив res_array, в котором я установил его значение (1-y_dist)*(x_dist) для всех индексов, кроме тех, где условие ((x_dist <= 0) | ((x_dist > 0) & (y_dist > (1 + x_dist)))) равно True. Моя текущая реализация выглядит следующим образом.

res_array  = (1-y_dist)*(x_dist)
res_array[((x_dist <= 0) | ((x_dist > 0) & (y_dist > (1 + x_dist))))] = 0.0

Однако мне нужно запускать код, содержащий этот фрагмент кода, тысячи раз, и я уверен, что есть более умный и более быстрый способ сделать то же самое. Не могли бы вы помочь мне получить код лучше с точки зрения производительности или однострочник?

Я заранее благодарен за вашу помощь!

1 Ответ

2 голосов
/ 01 августа 2020

Numba JIT можно использовать для этого эффективно. Вот реализация:

@njit
def fastImpl(x_dist, y_dist):
    res_array = np.empty(x_dist.shape)
    for z in range(res_array.shape[0]):
        for y in range(res_array.shape[1]):
            for x in range(res_array.shape[2]):
                xDist = x_dist[z,y,x]
                yDist = y_dist[z,y,x]
                if xDist > 0.0 and yDist <= (1.0 + xDist):
                    res_array[z,y,x] = (1.0 - yDist) * xDist
    return res_array

Вот результаты производительности для случайных входных матриц:

Original implementation: 494 µs ± 6.23 µs per loop (mean ± std. dev. of 7 runs, 500 loops each)
New implementation: 37.8 µs ± 236 ns per loop (mean ± std. dev. of 7 runs, 500 loops each)

Новая реализация примерно в 13 раз быстрее (без учета компиляции / прогрева время работы).

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...