Я портирую код из Tensorflow в Jax и сталкиваюсь со следующей трудностью:
У меня есть два массива, R и S. У нас есть:
R.shape
(10,201,11)
и
S.shape
(61,11)
Мне нужно свернуть каждый S [:, i] с соответствующим R [j,:, i] для всех j от 0: 9, что приведет к выводу shape = [10,201]. Это можно сделать в Tensorflow, выполнив следующие действия:
R1 = tf.expand_dims(R,axis=1)
output=tf.nn.depthwise_conv2d(R1,S,strides=[1,1,1,1],padding='SAME')
Использование функции тензорного потока tf.nn.depthwise_conv2d .
Мне интересно, есть ли способ сделать это с помощью jax.lax.conv . Существует небольшое руководство по функциям свертки Jax здесь , но оно довольно трудное для понимания и, похоже, предназначено для общих 2d сверток; не 1d свертки применяются равномерно по всем рядам.