трансляция умножения матриц в nd4j - PullRequest
0 голосов
/ 05 декабря 2018

В python предположим, что

a = np.array(range(0,12)).reshape(2,2,3)
b = np.array(range(0,6)).reshape(3,2)
c = np.matmul(a,b) // a @ b

У нас есть

a: array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 6,  7,  8],
        [ 9, 10, 11]]])

b: array([[0, 1],
       [2, 3],
       [4, 5]])

c: array([[[10, 13],
        [28, 40]],

       [[46, 67],
        [64, 94]]])

Может ли кто-нибудь помочь мне выполнить эквивалентную операцию в java nd4j без цикла for?Я пробовал broadcast.mul, но оказывается, broadcast.mul - это поэлементное умножение.Я не нашел ни одной трансляции для mmul.

1 Ответ

0 голосов
/ 05 декабря 2018

Я понял это сам.Ответ показан ниже на случай, если кому-то это нужно.С Nd4j.tensorMmul матричное вещание может быть легко достигнуто.например,

val a = Nd4j.create(0d to 11d by 1d toArray, Array[Int](2, 2, 3))
val b = Nd4j.create(0d to 5d by 1d toArray, Array[Int](3, 2))
Nd4j.tensorMmul(a, b, Array(Array(2), Array(0))) // matrix broadcast

Это код для Scala.Для Java вам просто нужно изменить код для создания массивов.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...