Продукт Numpy array dot - «агрегирование» строк без оценки всей вещи - PullRequest
1 голос
/ 11 июля 2019

Я хочу вычислить внутреннее произведение двумерного массива с собой - то есть np.inner (A, A) - и затем для каждой строки извлечь 2-е наибольшее значение и его индекс:

import numpy as np
import heapq

A = np.random.rand(1000,1000)
prod = np.inner(A,A)
tmp = []
for i, x in enumerate(prod): 
    idx = heapq.nlargest(2, range(len(x)), key=x.__getitem__)[1]
    max_val = heapq.nlargest(2, x)[1]
    tmp.append((i, idx, val))

Однако, если A становится огромным, не представляется возможным сохранить весь продукт в памяти, когда фактически требуется только две строки за один раз. Это было бы чрезвычайно легко реализовать в C, например, но я не уверен, как это сделать в Python.

Кажется, что должен быть элегантный способ решить эту проблему с помощью numpy или scipy, но я не смог понять это.

Ответы [ 2 ]

2 голосов
/ 11 июля 2019

Мы можем использовать np.argpartition, который делает indirect partition и, таким образом, достигает некоторой эффективности там -

def nth_largest(prod): # works on prod from numpy.inner output
    idx = np.argpartition(prod,-2,axis=1)[:,-2:]
    I = np.arange(len(idx))
    idx_s = prod[I[:,None],idx].argsort(1)
    n_largest_indices = idx[I,idx_s[:,0]]
    max_vals = prod[I,n_largest_indices]
    return list(zip(I,n_largest_indices,max_vals))

Если вашей главной задачей является память, прибегните к циклу -

def innerprod_nth_largest_loopy(A, k): # works on input A
    idxs = np.empty(len(A),dtype=np.uint64)
    vals = np.empty(len(A),dtype=A.dtype)
    for i,a in enumerate(A):
        r = a.dot(A.T)
        idx = np.argpartition(r,-k)[-k:]
        idxs[i] = idx[r[idx].argsort()[0]]
        vals[i] = r[idxs[i]]
    return list(zip(range(len(A)),idxs,vals))

Заметьте, однако, что петлевая версия будет намного медленнее, просто хорошо со стороны памяти.

1 голос
/ 11 июля 2019

думаю за

prod = np.inner(A, A)

i-й ряд prod равен

prod[i, :] = np.inner(A[i, :], A)

Так могли бы вы использовать цикл for и рассчитывать только 2-е наибольшее значение для одной строки за раз?

...