Ожидается, что декодер Keras Encoder будет иметь 2 размера - PullRequest
0 голосов
/ 10 декабря 2018

Декодер Keras Encoder возвращает InvalidArgumentError, поскольку формы входных данных кажутся несовместимыми.

У меня есть:

  • X_numeric.shape дает (304, 2500, 4) как входные данные
  • y_numeric.shape дает (304, 40, 22) как выходные данные

Кодер-декодер Keras имеет следующий вид:

# Define an input sequence and process it. 
encoder_inputs = Input(shape=(None, 4)) 
encoder = LSTM(32, return_state=True) 
encoder_outputs, state_h, state_c = encoder(encoder_inputs)

# We discard `encoder_outputs` and only keep the states. 
encoder_states = [state_h, state_c]

# Set up the decoder, using `encoder_states` as initial state. 
decoder_inputs = Input(shape=(None, 22))
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference. 
decoder_lstm = LSTM(32, return_sequences=True, return_state=True) 
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
                                     initial_state=encoder_states) 
decoder_dense = Dense(22, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data` 
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Run training 
model.compile(optimizer='rmsprop', loss='categorical_crossentropy') 

### THE ERROR OCCURS IN THE `.fit()` CALL
model.fit([X_numerical, y_numerical], y_numerical,
          batch_size=4,
          epochs=1)

И я получаю следующую ошибку:

InvalidArgumentError Traceback (последний последний вызов) в () 25 model.fit ([X_numeric, y_numeric], y_numeric,26 batch_size = 4, ---> 27 эпох = 1)

InvalidArgumentError: Несовместимые фигуры: [4,40,22] против [1,22,1]
[[Узел: training_6 /RMSprop / gradients / dens_15 / add_grad / BroadcastGradientArgs = BroadcastGradientArgs [T = DT_INT32, _class = ["loc: @ training_6 / RMSprop / gradients / density_15 / add_grad / Sum"], _device = "/ job: localhost / replica: 0 / task: 0 / устройство: ЦП: 0 "] (обучение_6 / RMSprop / градиенты / плотность_15 / add_grad / форма, обучение_6 / RMSprop / градиенты / плотность_15 / add_grad / Shape_1)]]

Я пытался изменить форму y_numeric на (304, 22, 40), но не работает.Я также попытался y_numerical .squeeze() и изменил batch_size в вызове model.fit (), и все они возвращали различные ошибки.

В чем может быть причина этой ошибки размерности?

Резюмемоя модель:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_37 (InputLayer)           (None, None, 4)      0                                            
__________________________________________________________________________________________________
input_38 (InputLayer)           (None, None, 22)     0                                            
__________________________________________________________________________________________________
lstm_39 (LSTM)                  [(None, 32), (None,  4736        input_37[0][0]                   
__________________________________________________________________________________________________
lstm_40 (LSTM)                  [(None, None, 32), ( 7040        input_38[0][0]                   
                                                                 lstm_39[0][1]                    
                                                                 lstm_39[0][2]                    
__________________________________________________________________________________________________
dense_19 (Dense)                (None, None, 22)     726         lstm_40[0][0]                    
==================================================================================================
Total params: 12,502
Trainable params: 12,502
Non-trainable params: 0
__________________________________________________________________________________________________
...