Что такое тензорот этой операции 4D einsum? - PullRequest
0 голосов
/ 11 декабря 2018

Вот простой код, который «пакетно умножает» матрицу 4D a на матрицу 3D b:

from functools import reduce
import numpy as np
from operator import mul

def einsum(a, b):
    return np.einsum('ijkl,jkl->ikl', a, b)

def original(a, b):
    s0, s1, s2, s3 = a.shape
    c = np.empty((s0, s2, s3))
    for j in range(s3):
        for i in range(s2):
            c[:, j, i] = np.dot(a[:, :, j, i], b[:, j, i])
    return c

sz_a = (16, 4, 512, 512)
sz_b = (4, 512, 512)

a = np.random.random(reduce(mul, sz_a)).reshape(sz_a)
b = np.random.random(reduce(mul, sz_b)).reshape(sz_b)

Для расчета времени:

%timeit original(a, b)
395 ms ± 2.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit einsum(a, b)
23.1 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Я бы хотел проверить тензордопроизводительность, чтобы увидеть, как она сравнивается, но у меня действительно есть некоторые проблемы, заключающиеся в том, как использовать это здесь.Если кто-то достаточно знаком, чтобы вести меня с этим, это будет очень цениться.Спасибо!

Моя первоначальная мысль была:

np.tensordot(a, b, axes=((1),(0)))

Но это дает мне MemoryError, так что я не думаю, что это правильно ...

1 Ответ

0 голосов
/ 11 декабря 2018

Сравнение времени вашего einsum с matmul эквивалентом:

In [910]: timeit (a.transpose(2,3,0,1)@b[:,None].transpose(2,3,0,1)).transpose(2,3,0,1)[:
     ...: ,0]
90.5 ms ± 92.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [911]: timeit np.einsum('ijkl,jkl->ikl', a, b)
92.7 ms ± 2.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Время достаточно близко, что я подозреваю, что einsum оптимизация фактически использует matmul.Первоначально einsum использовал свою собственную скомпилированную итерацию суммы продуктов, но в последнее время с недавними изменениями он использует различные методы, включая dot и matmul, если они подходят.

matmul был создандля обработки случая, когда начальные размеры представляют собой стек матриц.В вашей задаче последние 2 измерения - это стек, с dot, действующим на начальное.matmul был создан для работы с такими сложенными точками.dot и его производные tensordot не справляются с такими стеками.

...