Вы можете видеть документ о tf.gather()
:
Создает выходной тензор с формой params.shape [: axis] + indices.shape + params.форма [ось + 1:]
Ваша форма параметров (6040,3706)
и форма индексов (?,1)
.
Таким образом, форма вывода params.shape[:0] + indices.shape + params.shape[1:]
= () + (?,1) + (3706,)
при установке axis=0
.
И форма вывода params.shape[:1] + indices.shape + params.shape[2:]
= (6040,) + (?,1) + ()
при установке axis=1
.
Вы можете использовать tf.transpose()
для перестановки осей.
import tensorflow as tf
import keras.backend as K
from keras.layers import Input,Lambda
import numpy as np
user_item_matrix = K.constant(np.zeros(shape=(6040,3706)))
# Input variables
user_input = Input(shape=(1,), dtype='int32', name='user_input')
item_input = Input(shape=(1,), dtype='int32', name='item_input')
# Embedding layer
user_rating = Lambda(lambda x: tf.gather(user_item_matrix, tf.to_int32(x), axis=0))(K.squeeze(user_input,axis=1))
item_rating = Lambda(lambda x: tf.transpose(tf.gather(user_item_matrix, tf.to_int32(x), axis=1),(1,0)))(K.squeeze(item_input,axis=1))
print(user_rating.shape)
print(item_rating.shape)
# print
(?, 3706)
(?, 6040)