Вам нужен ввод, чтобы выбрать, какое встраивание вы используете.
Поскольку вы используете 150 слов, ваши вложения будут иметь форму (batch,150,200)
, которую невозможно каким-либо образом объединить с (batch, 56)
. Вам нужно что-то преобразовать, чтобы соответствовать форме. Я предлагаю вам попробовать слой Dense
, чтобы преобразовать 56 в 200 ...
word_input = Input((150,))
normal_input = Input((56,))
embedding = pretrained_embeddings(word_input)
normal = Dense(200)(normal_input)
#you could add some normalization here - read below
normal = Reshape((1,200))(normal)
concatenated = Concatenate(axis=1)([normal, embedding])
Я также предлагаю, поскольку вложения и ваши входные данные имеют разную природу, что вы применяете нормализацию, чтобы они стали более похожими :
embedding = BatchNormalization(center=False, scale=False)(embedding)
normal = BatchNormalization(center=False, scale=False)(normal)
Другая возможность (я не могу сказать, что лучше) - это объединить в другом измерении, преобразовав 56 вместо 150:
word_input = Input((150,))
normal_input = Input((56,))
embedding = pretrained_embeddings(word_input)
normal = Dense(150)(normal_input)
#you could add some normalization here - read below
normal = Reshape((150,1))(normal)
concatenated = Concatenate(axis=-1)([normal, embedding])
I Если вы считаете, что это больше подходит для периодических и сверточных сетей, вы добавляете новый канал вместо добавления нового шага.
Вы даже можете попробовать двойное объединение, что звучит круто: D
word_input = Input((150,))
normal_input = Input((56,))
embedding = pretrained_embeddings(word_input)
normal150 = Dense(150)(normal_input)
normal201 = Dense(201)(normal_input)
embedding = BatchNormalization(center=False, scale=False)(embedding)
normal150 = BatchNormalization(center=False, scale=False)(normal150)
normal201 = BatchNormalization(center=False, scale=False)(normal201)
normal150 = Reshape((150,1))(normal150)
normal201 = Reshape((1,201))(normal201)
concatenated = Concatenate(axis=-1)([normal150, embedding])
concatenated = Concatenate(axis= 1)([normal201, concatenated])