Трансляция трехмерных массивов для поэлементного умножения - PullRequest
0 голосов
/ 06 сентября 2018

Добрый вечер,

Мне нужна помощь в понимании расширенного вещания со сложными массивами.

У меня есть:

массив A: 50000x2000

массив B: 2000x10x10

Реализация с циклом for:

for k in range(50000):
    temp = A[k,:].reshape(2000,1,1)
    finalarray[k,:,:]=np.sum ( B*temp , axis=0)

Я хочу поэлементное умножение и суммирование оси с 2000 элементами с конечным продуктом:

finalarray: 50000x10x10

Можно ли избежать цикла for? Спасибо!

Ответы [ 2 ]

0 голосов
/ 06 сентября 2018

Для чего-то подобного я бы использовал np.einsum, что позволяет довольно легко записать, что вы хотите сделать, с точки зрения действий с индексами, которые вы хотите:

fast = np.einsum('ij,jkl->ikl', A, B)

, который дает мне тот же результат (отбрасывание 50000-> 500, так что циклический завершается быстро):

A = np.random.random((500, 2000))
B = np.random.random((2000, 10, 10))
finalarray = np.zeros((500, 10, 10))
for k in range(500):
    temp = A[k,:].reshape(2000,1,1)
    finalarray[k,:,:]=np.sum ( B*temp , axis=0)

fast = np.einsum('ij,jkl->ikl', A, B)

дает мне

In [81]: (finalarray == fast).all()
Out[81]: True

и разумная производительность даже в корпусе 50000:

In [88]: %time fast = np.einsum('ij,jkl->ikl', A, B)
Wall time: 4.93 s

In [89]: fast.shape

Out[89]: (50000, 10, 10)

В качестве альтернативы, в этом случае вы можете использовать tensordot:

faster = np.tensordot(A, B, axes=1)

, что будет в несколько раз быстрее (за счет менее общего):

In [29]: A = np.random.random((50000, 2000))

In [30]: B = np.random.random((2000, 10, 10))

In [31]: %time fast = np.einsum('ij,jkl->ikl', A, B)
Wall time: 5.08 s

In [32]: %time faster = np.tensordot(A, B, axes=1)
Wall time: 504 ms

In [33]: np.allclose(fast, faster)
Out[33]: True

Я должен был использовать allclose здесь, потому что значения очень немного отличаются:

In [34]: abs(fast - faster).max()
Out[34]: 2.7853275241795927e-12
0 голосов
/ 06 сентября 2018

Это должно работать:

(A[:, :, None, None] * B[None, :, :]).sum(axis=1)

Но это взорвет вашу память для промежуточного массива, созданного продуктом.

Продукт имеет форму (50000, 2000, 10, 10), поэтому содержит 10 миллиардов элементов, что составляет 80 ГБ для 64-разрядных значений с плавающей запятой.

...