Я пытаюсь получить выходные данные из слоя LSTM за шаг по времени и только на последнем шаге по времени (вывод шага и вектор контекста) отдельно, поэтому я обнаружил, что решение сделать это - сделать лямбдуслой, который извлекает вектор контекста из LSTM с return_sequences=True
.В последовательной модели он работал нормально, но когда я пытаюсь реализовать его в функциональном API, он внезапно перестает принимать измерения, заявляя, что все имеет ndim = 1, хотя это не так.код:
def ContextVector(x):
return x[-1][-1]
def ContextVectorOut(input_shape):
print([None, input_shape[-1]])
print((input_shape[::2]))
print(input_shape)
return list((None, input_shape[-1]))
input_layer = Input(shape=(10, 5))
LSTM_layer = LSTM(5, return_sequences=True)(input_layer)
context_layer = Lambda(ContextVector, output_shape=ContextVectorOut)(LSTM_layer)
repeat_context_layer = RepeatVector(10, name='context')(context_layer)
timed_dense = TimeDistributed(Dense(10))(LSTM_layer)
connected_dense = Dense(2)
connect_dense_context = connected_dense(repeat_context_layer)
connect_dense_time = connected_dense(timed_dense)
concat_out = concatenate([connect_dense_context, connect_dense_time])
output_dense = Dense(5)(concat_out)
model = Model(inputs = [input_layer], output = output_dense)
#model.add(LSTM(20, input_shape = (10, 5), return_sequences=True))
#model.add(Lambda(ContextVector, output_shape=ContextVectorOut))
#model.add(Dense(1))
model.summary()
Ошибка:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-74-016b4a976d40> in <module>()
10 LSTM_layer = LSTM(5, return_sequences=True)(input_layer)
11 context_layer = Lambda(ContextVector, output_shape=ContextVectorOut)(LSTM_layer)
---> 12 repeat_context_layer = RepeatVector(10, name='context')(context_layer)
13 timed_dense = TimeDistributed(Dense(10))(LSTM_layer)
14 connected_dense = Dense(2)
C:\ProgramData\Miniconda3\lib\site-packages\keras\engine\base_layer.py in __call__(self, inputs, **kwargs)
412 # Raise exceptions in case the input is not compatible
413 # with the input_spec specified in the layer constructor.
--> 414 self.assert_input_compatibility(inputs)
415
416 # Collect input shapes to build layer.
C:\ProgramData\Miniconda3\lib\site-packages\keras\engine\base_layer.py in assert_input_compatibility(self, inputs)
309 self.name + ': expected ndim=' +
310 str(spec.ndim) + ', found ndim=' +
--> 311 str(K.ndim(x)))
312 if spec.max_ndim is not None:
313 ndim = K.ndim(x)
ValueError: Input 0 is incompatible with layer context: expected ndim=2, found ndim=1