Что означает «a» в строке np.einsum? - PullRequest
0 голосов
/ 14 июня 2019

Я пытаюсь преобразовать некоторый код для работы с Numba.np.einsum не поддерживается, поэтому я пытаюсь заменить его функциями, поддерживаемыми Numba.

Я частично понял, как работает np.einsum, и, например, я понял, что:

x, y, z = 3, 2, 4
A = np.arange(x * y * z).reshape(x, y, z)
B = np.arange(x * y).reshape(x, y)

C = np.einsum('ijk,kj->ki', A.T, B)

эквивалентно:

C = np.sum(A.T * B.T, axis=1).T

например, я беру ijk и трехмерные канонические индексы, но теперь у меня есть следующее выражение, которое я не могу понять:

C = np.einsum('aij,jka->ajk', A, B)

В чем смыслиндекс 'a'?Что было бы эквивалентным преобразованием с использованием умножения, суммирования и транспонирования?

1 Ответ

3 голосов
/ 14 июня 2019

Какие буквы вы используете в строке осей, не имеет большого значения (но смотрите нижнюю часть этого поста), например, мы можем поставить z для a:

>>> A = np.arange(3*4*5).reshape(3,4,5)
>>> B = np.arange(5*2*3).reshape(5,2,3)
>>> 
>>> np.einsum('aij,jka->ajk',A,B)
array([[[   0,   90],
        [ 204,  306],
        [ 456,  570],
        [ 756,  882],
        [1104, 1242]],

       [[ 110,  440],
        [ 798, 1140],
        [1534, 1888],
        [2318, 2684],
        [3150, 3528]],

       [[ 380,  950],
        [1552, 2134],
        [2772, 3366],
        [4040, 4646],
        [5356, 5974]]])
>>> np.einsum('zij,jkz->zjk',A,B)
array([[[   0,   90],
        [ 204,  306],
        [ 456,  570],
        [ 756,  882],
        [1104, 1242]],

       [[ 110,  440],
        [ 798, 1140],
        [1534, 1888],
        [2318, 2684],
        [3150, 3528]],

       [[ 380,  950],
        [1552, 2134],
        [2772, 3366],
        [4040, 4646],
        [5356, 5974]]])

Эквивалент без einsum:

>>> A.sum(1)[..., None]*B.transpose(2,0,1)
array([[[   0,   90],
        [ 204,  306],
        [ 456,  570],
        [ 756,  882],
        [1104, 1242]],

       [[ 110,  440],
        [ 798, 1140],
        [1534, 1888],
        [2318, 2684],
        [3150, 3528]],

       [[ 380,  950],
        [1552, 2134],
        [2772, 3366],
        [4040, 4646],
        [5356, 5974]]])

Идентичность индексных букв имеет значение, если выходные оси неявны, поскольку предполагается, что они в алфавитном порядке

>>> A = np.ones((2,1))
>>> np.einsum('ab', A)
array([[1.],
       [1.]])
>>> np.einsum('zb', A)
array([[1., 1.]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...