У меня есть решение для случая, когда указан размер оси time_steps
(т.е. не None
).Мы можем легко использовать K.repeat_elements
и K.tile
для формирования тензоров для декартового произведения:
from keras import layers, models
from keras import backend as K
def some_func(a, b):
# define the some_func here
return a + b
def cart_prod(x):
shp = K.int_shape(x)[1]
x_rep = K.repeat_elements(x, shp, axis=1)
x_tile = K.tile(x, [1, shp, 1])
res = some_func(x_rep, x_tile)
return K.reshape(res, [-1, shp, shp, K.shape(res)[-1]])
inp = layers.Input((3, 2))
out = layers.Lambda(cart_prod)(inp)
model = models.Model(inp, out)
model.predict(np.arange(6).reshape(1, 3, 2))
Выход:
array([[[[ 0., 2.],
[ 2., 4.],
[ 4., 6.]],
[[ 2., 4.],
[ 4., 6.],
[ 6., 8.]],
[[ 4., 6.],
[ 6., 8.],
[ 8., 10.]]]], dtype=float32)