Аналог tf.depthwise_conv2d Использование Jax jax.lax.conv - PullRequest
0 голосов
/ 10 марта 2020

Я портирую код из Tensorflow в Jax и сталкиваюсь со следующей трудностью:

У меня есть два массива, R и S. У нас есть:

R.shape
(10,201,11)

и

S.shape
(61,11)

Мне нужно свернуть каждый S [:, i] с соответствующим R [j,:, i] для всех j от 0: 9, что приведет к выводу shape = [10,201]. Это можно сделать в Tensorflow, выполнив следующие действия:

R1 = tf.expand_dims(R,axis=1)
output=tf.nn.depthwise_conv2d(R1,S,strides=[1,1,1,1],padding='SAME')

Использование функции тензорного потока tf.nn.depthwise_conv2d .

Мне интересно, есть ли способ сделать это с помощью jax.lax.conv . Существует небольшое руководство по функциям свертки Jax здесь , но оно довольно трудное для понимания и, похоже, предназначено для общих 2d сверток; не 1d свертки применяются равномерно по всем рядам.

...