Я встраиваю текстовый документ, используя следующий код:
from tensorflow.keras import layers
from tensorflow.keras import optimizers
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization
# Model constants.
MAX_FEATURES = 10000
EMBEDDING_DIM = 128
SEQUENCE_LENGTH = 1000
vectorize_layer = TextVectorization(
standardize='lower_and_strip_punctuation',
max_tokens=MAX_FEATURES,
output_mode='int',
output_sequence_length=SEQUENCE_LENGTH,
)
vectorize_layer.adapt(docs)
text_input = tf.keras.Input(shape=(1,), dtype=tf.string)
vectorized_text = vectorize_layer(text_input)
x = layers.Embedding(MAX_FEATURES+1, EMBEDDING_DIM, input_length=SEQUENCE_LENGTH)(vectorized_text)
После слоя встраивания я хотел бы добавить дополнительные функции - это один массив numpy с горячим кодированием, который содержит 11 столбцов и строку на образец.
one_hot_input = tf.keras.Input(shape=(1, 11), dtype='float32')
x = layers.Concatenate(axis=-1)([x, one_hot_input])
Я пробовал различные комбинации для оси конкатенации, а также для формы, но безуспешно.
ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [(None, 1000, 128), (None, 1, 11)]
Модель будет компилироваться при настройке следующие параметры:
one_hot_input = tf.keras.Input(shape=(None, 11), dtype='float32')
x = layers.Concatenate(axis=-1)([x, one_hot_input])
Но при запуске выдает ошибку fit()
:
ValueError: Shape must be rank 3 but is rank 2 for '{{node functional_30/concatenate_36/concat}} = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32](functional_30/dropout_54/dropout/Mul_1, functional_30/Cast, functional_30/concatenate_36/concat/axis)' with input shapes: [?,1000,128], [?,11], [].
EDIT:
Вот сводка компилируемой, но не работающей модели:
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_68 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
text_vectorization_2 (TextVecto (None, 1000) 0 input_68[0][0]
__________________________________________________________________________________________________
embedding_56 (Embedding) (None, 1000, 128) 1280128 text_vectorization_2[20][0]
__________________________________________________________________________________________________
dropout_64 (Dropout) (None, 1000, 128) 0 embedding_56[0][0]
__________________________________________________________________________________________________
input_69 (InputLayer) [(None, None, 11)] 0
__________________________________________________________________________________________________
concatenate_44 (Concatenate) (None, 1000, 139) 0 dropout_64[0][0]
input_69[0][0]
__________________________________________________________________________________________________
dense_23 (Dense) (None, 1000, 139) 19460 concatenate_44[0][0]
__________________________________________________________________________________________________
dropout_65 (Dropout) (None, 1000, 139) 0 dense_23[0][0]
__________________________________________________________________________________________________
predictions (Dense) (None, 1000, 1) 140 dropout_65[0][0]
==================================================================================================
Total params: 1,299,728
Trainable params: 1,299,728
Non-trainable params: 0