Я использую декодер с вниманием, мой код работает, когда ввод (shape = (None,)) для декодера. но я не знаю почему - PullRequest
0 голосов
/ 24 февраля 2020

Здесь вы можете видеть, что я объявил Input (shape = (None,)) для декодера, но я хочу знать, как он работает во время обучения и тестирования.

emb_dim = 300

encoder_input = Input(shape=(60,))
x1=Embedding(vocab_size, 300,weights=[input_matrix],trainable=False)(encoder_input)
e_lstm_out, e_hidden_out, e_cell_out = LSTM(32,return_sequences=True,return_state=True,dropout=0.4)(x1)

decoder_input = Input(shape=(None,))

decoder_embedding_layer = Embedding(y_vocab_size, 300,trainable=True)
decoder_embedding = decoder_embedding_layer(decoder_input)

decoder_lstm = LSTM(32, return_sequences=True, return_state=True,dropout=0.4)
d_lstm_out,d_hidden_out,d_cell_out = decoder_lstm(decoder_embedding,initial_state=[e_hidden_out, e_cell_out])

attention_layer = AttentionLayer(name='attention_layer')
attention_out, attention_states = attention_layer([e_lstm_out, d_lstm_out])

# Concat attention input and decoder LSTM output
concat = Concatenate(axis=-1, name='concat_layer')([d_lstm_out, attention_out])

#dense layer
decoder_dense =  TimeDistributed(Dense(y_vocab_size, activation='softmax'))
decoder_dense_outputs = decoder_dense(concat)

# Define the model 
model = Model([encoder_input, decoder_input], decoder_dense_outputs)

model.summary()

и итоги Ниже приведено, что вы можете видеть, что input_2 (InputLayer) [(None, None)]

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 60)]         0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, None)]       0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 60, 300)      5206500     input_1[0][0]                    
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, None, 300)    6503400     input_2[0][0]                    
__________________________________________________________________________________________________
lstm (LSTM)                     [(None, 60, 32), (No 42624       embedding[0][0]                  
__________________________________________________________________________________________________
lstm_1 (LSTM)                   [(None, None, 32), ( 42624       embedding_1[0][0]                
                                                             lstm[0][1]                       
                                                             lstm[0][2]                       
__________________________________________________________________________________________________
attention_layer (AttentionLayer ((None, None, 32), ( 2080        lstm[0][0]                       
                                                             lstm_1[0][0]                     
__________________________________________________________________________________________________
concat_layer (Concatenate)      (None, None, 64)     0           lstm_1[0][0]                     
                                                             attention_layer[0][0]            
__________________________________________________________________________________________________
time_distributed (TimeDistribut (None, None, 21678)  1409070     concat_layer[0][0]


Total params: 13,206,298
Trainable params: 7,999,798
Non-trainable params: 5,206,500
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...