Сделайте l oop с ломтиком матрицы более эффективным - PullRequest
1 голос
/ 05 января 2020

Я хочу сделать приведенный ниже код более эффективным, но я не уверен, как это сделать. Я хочу использовать только библиотеки numpy и нативные python.

iterations = 100
aggregation = 0
for i in range(iterations):
    aggregation += np.sum(np.linalg.norm(dat[dat_filter==i] - dat_points[i], axis=1))

dat - это матрица nxD. Dat_filter - это вектор длины n, содержащий идентификатор от 0 до num_iterations. Dat_points - это num_iterators x D matrix.

В основном я вычисляю расстояния между матрицей Dat, точки которой принадлежат классу, и точками этого класса

1 Ответ

1 голос
/ 05 января 2020

Нелегко векторизовать проблему, поскольку у вас есть квадратные корни частей ваших данных, которые не обязательно имеют одинаковую длину. Вы можете векторизовать его части для небольшого ускорения:

import numpy as np

# Make some data
n = 200000
d = 100
iterations = 2000

np.random.seed(42)
dat = np.random.random((n, d))
dat_filter = np.random.randint(0, n_it, size=n)
dat_points = np.random.random((n_it, d))


def slow(dat, dat_filter, dat_points, iterations):
    aggregation = 0
    for i in range(iterations):
        # Wrote linalg.norm as standard numpy operations,
        # such that numba can be used on the code as well
        aggregation += np.sum(np.sqrt(np.sum((dat[dat_filter==i] - dat_points[i])**2, axis=1)))
    return aggregation

def fast(dat, dat_filter, dat_points, iterations):
    # Rearrange the arrays such that the correct operations are done
    sort_idx = np.argsort(dat_filter)
    filtered_dat_squared_sum = np.sum((dat - dat_points[dat_filter])**2, axis=1)[sort_idx]
    # Count the number of different 'iterations'
    counts = np.unique(dat_filter, return_counts=True)[1]
    aggregation = 0 
    idx = 0 
    for c in counts:
        aggregation += np.sum(np.sqrt(filtered_dat_squared_sum[idx:idx+c]))
        idx += c
    return aggregation

время:

In [1]: %timeit slow(dat, dat_filter, dat_points, n_it)       
3.47 s ± 314 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [2]: %timeit fast(dat, dat_filter, dat_points, n_it)     
846 ms ± 81.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Использование numba с функцией slow немного ускоряет, но все же не так быстро как метод fast. Numba с функцией fast замедляет вызов на протестированных мною размерах матрицы.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...