Когда трудно найти способ кодирования, эффективный и совместимый с numpy, но когда код с петлями for
тривиален, вы можете использовать njit
из numba
.
Лучше всего обрабатывать плоские массивы, поэтому сначала напишем функцию в numba, которая выполняет то, что вы просите, но в 1d:
from numba import njit, int64
@njit
def fast_max_flat(img_flat, imglab_flat):
n_cells =int(imglab_flat.max()) # number of cells
max_values = np.full(n_cells, - np.inf) # stores the n_cells max values seen so far
max_coords = np.zeros(n_cells, dtype=int64) # stores the corresponding coordinate
n_pixels = len(img)
for i in range(n_pixels):
label = imglab_flat[i]
value = img_flat[i]
if max_values[label] < value:
max_values[label] = value
max_coords[label] = i
return max_coords
А затем напишите оболочку Python, которая разбивает массив, применяет предыдущую функцию и получает координаты в виде списка:
def wrapper(img, imglab):
dim = img.shape
coords = fast_max_flat(img.ravel(), imglab.ravel())
return [np.unravel_index(coord, dim) for coord in coords]
На моей машине с 100 x 100 x 100
изображением из 3 ячеек это примерно в 50 раз быстрее, чем ваш метод.