Каков наилучший способ вычисления строчных (или осевых) точечных продуктов с помощью jax? - PullRequest
0 голосов
/ 20 апреля 2020

У меня есть два числовых массива формы (N, M). Я хотел бы вычислить строчное произведение точек. Т.е. создайте массив формы (N,) такой, что n-я строка является точечным произведением n-й строки из каждого массива.

Мне известен метод numpy inner1d. Каков наилучший способ сделать это с Jax? у jax jax.numpy.inner, но это делает что-то еще.

Ответы [ 2 ]

1 голос
/ 20 апреля 2020

Вы можете попробовать jax. numpy .einsum . Здесь реализация с использованием numpy einsum

import numpy as np
from numpy.core.umath_tests import inner1d

arr1 = np.random.randint(0,10,[5,5])
arr2 = np.random.randint(0,10,[5,5])

arr = np.inner1d(arr1,arr2)
arr
array([ 87, 200, 229,  81,  53])
np.einsum('...i,...i->...',arr1,arr2)
array([ 87, 200, 229,  81,  53])
0 голосов
/ 28 апреля 2020

Вы можете определить свою собственную jit-скомпилированную версию inner1d в несколько строк кода jax:

import jax
@jax.jit
def inner1d(X, Y):
  return (X * Y).sum(-1)

Тестирование:

import jax.numpy as jnp
import numpy as np
from numpy.core import umath_tests


X = np.random.rand(5, 10)
Y = np.random.rand(5, 10)

print(umath_tests.inner1d(X, Y))
print(inner1d(jnp.array(X), jnp.array(Y)))
# [2.23219571 2.1013316  2.70353783 2.14094973 2.62582531]
# [2.2321959 2.1013315 2.703538  2.1409497 2.6258256]
...