Вычислить сумму векторов в массиве numpy на основе значений словаря - PullRequest
0 голосов
/ 28 мая 2018

У меня есть массив, подобный следующему, но гораздо больше:

array = np.random.randint(6, size=(5, 4))

array([[4, 3, 0, 2],
   [1, 4, 3, 1],
   [0, 3, 5, 2],
   [1, 0, 5, 3],
   [0, 5, 4, 4]])

У меня также есть словарь, который дает мне векторное представление каждого значения в этом массиве:

dict_ = {2:np.array([3.4, 2.6, -1.2]), 0:np.array([0, 0, 0]), 1:np.array([3.9, 2.6, -1.2]), 3:np.array([3.8, 6.6, -1.9]), 4:np.array([5.4, 2.6, -1.2]),5:np.array([6.4, 2.6, -1.2])}

Я хочу вычислить среднее значение векторных представлений для каждой строки в массиве, но когда значение равно 0, игнорируйте его при расчете среднего (словарь показывает его как вектор 0).

Например, для первой строки следует усреднить [5.4, 2.6, -1.2], [3.8, 6.6, -1.9] и [3.4, 2.6, -1.2] и дать [4.2,3.93, -1.43] в качестве первой строки выходных данных.

Я хочу вывод, который сохраняет ту же структуру строк и имеет 3 столбца (каждый вектор в словаре имеет 3 значения).

Как это можно сделать эффективным способом?В моем фактическом словаре более 100000 записей, а массив - 100000 на 5000.

Ответы [ 2 ]

0 голосов
/ 28 мая 2018

Для эффективности я бы преобразовал dict в массив, а затем использовал бы расширенное индексирование для поиска:

>>> import numpy as np
>>> 
# create problem
>>> v = np.random.random((100_000, 3))
>>> dict_ = dict(enumerate(v))
>>> arr = np.random.randint(0, 100_000, (100_000, 100))
>>> 
# solve
>>> from operator import itemgetter
>>> lookup = np.array(itemgetter(*range(100_000))(dict_))
>>> lookup[0] = np.nan
>>> result = np.nanmean(lookup[arr], axis=1)

Или применимо к примеру OP:

>>> arr = np.array([[4, 3, 0, 2],
...    [1, 4, 3, 1],
...    [0, 3, 5, 2],
...    [1, 0, 5, 3],
...    [0, 5, 4, 4]])
>>> dict_ = {2:np.array([3.4, 2.6, -1.2]), 0:np.array([0, 0, 0]), 1:np.array([3.9, 2.6, -1.2]), 3:np.array([3.8, 6.6, -1.9]), 4:np.array([5.4, 2.6, -1.2]),5:np.array([6.4, 2.6, -1.2])}
>>> 
>>> lookup = np.array(itemgetter(*range(6))(dict_))
>>> lookup[0] = np.nan
>>> result = np.nanmean(lookup[arr], axis=1)
>>> result
array([[ 4.2       ,  3.93333333, -1.43333333],
       [ 4.25      ,  3.6       , -1.375     ],
       [ 4.53333333,  3.93333333, -1.43333333],
       [ 4.7       ,  3.93333333, -1.43333333],
       [ 5.73333333,  2.6       , -1.2       ]])

Временные сопоставления с методом @ jpp:

pp:    0.8046 seconds
jpp:  10.3449 seconds
results equal: True

Код для получения времени:

import numpy as np

# create problem
v = np.random.random((100_000, 3))
dict_ = dict(enumerate(v))
arr = np.random.randint(0, 100_000, (100_000, 100))

# solve
from operator import itemgetter
def f_pp(arr, dict_):
    lookup = np.array(itemgetter(*range(100_000))(dict_))
    lookup[0] = np.nan
    return np.nanmean(lookup[arr], axis=1)

def f_jpp(arr, dict_):
    def averager(x):
        lst = [dict_[i] for i in x if i]
        return np.mean(lst, axis=0) if lst else np.array([0, 0, 0])

    return np.apply_along_axis(averager, -1, arr)


from time import perf_counter
t = perf_counter()
r_pp = f_pp(arr, dict_)
s = perf_counter()
print(f'pp:  {s-t:8.4f} seconds')
t = perf_counter()
r_jpp = f_jpp(arr, dict_)
s = perf_counter()
print(f'jpp: {s-t:8.4f} seconds')
print('results equal:', np.allclose(r_pp, r_jpp))
0 голосов
/ 28 мая 2018

Это одно из решений, использующее numpy.apply_along_axis.

. Вы должны протестировать и протестировать, чтобы убедиться, что производительность соответствует вашему варианту использования.

A = np.random.randint(6, size=(5, 4))

print(A)

[[3 5 2 4]
 [2 4 5 2]
 [0 3 1 1]
 [3 4 4 5]
 [2 5 0 2]]

zeros = {k for k, v in dict_.items() if (v==0).all()}

def averager(x):
    lst = [dict_[i] for i in x if i not in zeros]
    return np.mean(lst, axis=0) if lst else np.array([0, 0, 0])

res = np.apply_along_axis(averager, -1, A)

array([[ 4.75      ,  3.6       , -1.375     ],
       [ 4.65      ,  2.6       , -1.2       ],
       [ 3.86666667,  3.93333333, -1.43333333],
       [ 5.25      ,  3.6       , -1.375     ],
       [ 4.4       ,  2.6       , -1.2       ]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...