Самый быстрый способ удалить / извлечь подматрицу из простой матрицы - PullRequest
0 голосов
/ 11 декабря 2018

У меня есть квадратная матрица NxN (обычно N> 500).Он построен с использованием массива numpy.

Мне нужно извлечь новую матрицу, в которой i-й столбец и строка удалены из этой матрицы.Новая матрица (N-1) x (N-1).

В настоящее время я использую следующий код для извлечения этой матрицы:

            new_mat = np.delete(old_mat,idx_2_remove,0)
            new_mat = np.delete(old_mat,idx_2_remove,1)

Я также пытался использовать:

row_indices = [i for i in range(0,idx_2_remove)]
row_indices += [i for i in range(idx_2_remove+1,N)]
col_indices = row_indices
rows = [i for i in row_indices for j in col_indices]
cols = [j for i in row_indices for j in col_indices]

old_mat[(rows, cols)].reshape(len(row_indices), len(col_indices))

Но я обнаружил, что это медленнее, чем использование np.delete() в первом.Первый все еще довольно медленный для моего приложения.

Есть ли более быстрый способ выполнить то, что я хочу?

Редактировать 1: Кажется, что следующее даже быстрее, чем два выше, но ненамного:

new_mat = old_mat[row_indices,:][:,col_indices]

1 Ответ

0 голосов
/ 11 декабря 2018

Вот 3 варианта, которые я быстро написал:

Повтор delete:

def foo1(arr, i):
    return np.delete(np.delete(arr, i, axis=0), i, axis=1)

Максимальное использование нарезки (может потребоваться несколько проверок края):

def foo2(arr,i):
    N = arr.shape[0]
    res = np.empty((N-1,N-1), arr.dtype)
    res[:i, :i] = arr[:i, :i]
    res[:i, i:] = arr[:i, i+1:]
    res[i:, :i] = arr[i+1:, :i]
    res[i:, i:] = arr[i+1:, i+1:]
    return res

Расширенное индексирование:

def foo3(arr,i):
    N = arr.shape[0]
    idx = np.r_[:i,i+1:N]
    return arr[np.ix_(idx, idx)]

Проверьте, работают ли они:

In [874]: x = np.arange(100).reshape(10,10)
In [875]: np.allclose(foo1(x,5),foo2(x,5))
Out[875]: True
In [876]: np.allclose(foo1(x,5),foo3(x,5))
Out[876]: True

Сравните время:

In [881]: timeit foo1(arr,100).shape
4.98 ms ± 190 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [882]: timeit foo2(arr,100).shape
526 µs ± 1.57 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [883]: timeit foo3(arr,100).shape
2.21 ms ± 112 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Таким образом, нарезка выполняется быстрее, даже если коддлиннееПохоже, np.delete работает как foo3, но по одному измерению за раз.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...