Модель RNN Keras для НЛП занимает много времени во время обучения без уменьшения потерь при проверке - PullRequest
1 голос
/ 21 июня 2019

Я построил модель RNN для распознавания сущностей. Я использовал вложение BERT, а затем обработал результаты с помощью модели RNN. Однако при обучении модели в течение 5 эпох каждая эпоха, кажется, занимает около 2 часов. Кроме того, потери при проверке, похоже, совсем не уменьшаются.

Я запускаю процесс на графическом процессоре RTX 2080. Я пытался манипулировать моделью, но не улучшил модель. У меня есть около 400000 предложений.

Это моя модель:

def build_model(max_seq_length, n_tags): 
    in_id = Input(shape=(max_seq_length,), name="input_ids")
    in_mask = Input(shape=(max_seq_length,), name="input_masks")
    in_segment = Input(shape=(max_seq_length,), name="segment_ids")

    bert_inputs = [in_id, in_mask, in_segment]   
    bert_output = BertLayer(n_fine_tune_layers=3, pooling="first")(bert_inputs)
    x = RepeatVector(max_seq_length)(bert_output)
    x = Bidirectional(LSTM(units=lstm_units, return_sequences=True,
                           recurrent_dropout=0.2, dropout=0.2))(x)
    x_rnn = Bidirectional(LSTM(units=lstm_units, return_sequences=True,
                               recurrent_dropout=0.2, dropout=0.2))(x)
    x = add([x, x_rnn])  # residual connection to the first biLSTM
    pred = TimeDistributed(Dense(n_tags, activation="softmax"))(x)

    model = Model(inputs=bert_inputs, outputs=pred)
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.summary()
    return model

Это сводка модели:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_ids (InputLayer)          (None, 30)           0                                            
__________________________________________________________________________________________________
input_masks (InputLayer)        (None, 30)           0                                            
__________________________________________________________________________________________________
segment_ids (InputLayer)        (None, 30)           0                                            
__________________________________________________________________________________________________
bert_layer_3 (BertLayer)        ((None, 30), 768)    110104890   input_ids[0][0]                  
                                                                 input_masks[0][0]                
                                                                 segment_ids[0][0]                
__________________________________________________________________________________________________
repeat_vector_2 (RepeatVector)  ((None, 30), 30, 768 0           bert_layer_3[0][0]               
__________________________________________________________________________________________________
bidirectional_2 (Bidirectional) ((None, 30), 30, 200 695200      repeat_vector_2[0][0]            
__________________________________________________________________________________________________
bidirectional_3 (Bidirectional) ((None, 30), 30, 200 240800      bidirectional_2[0][0]            
__________________________________________________________________________________________________
add_1 (Add)                     ((None, 30), 30, 200 0           bidirectional_2[0][0]            
                                                                 bidirectional_3[0][0]            
__________________________________________________________________________________________________
time_distributed_1 (TimeDistrib ((None, 30), 30, 3)  603         add_1[0][0]                      
==================================================================================================
Total params: 111,041,493
Trainable params: 22,790,811
Non-trainable params: 88,250,682
__________________________________________________________________________________________________

Журналы:

 32336/445607 [=>............................] - ETA: 2:12:59 - loss: 0.3469 - acc: 0.9068
 32352/445607 [=>............................] - ETA: 2:12:58 - loss: 0.3469 - acc: 0.9068
 32368/445607 [=>............................] - ETA: 2:12:58 - loss: 0.3469 - acc: 0.9068

Можете ли вы помочь мне узнать, где я иду не так?

1 Ответ

2 голосов
/ 21 июня 2019

Если вы используете Bert для встраивания, выходная форма должна быть (None, 30, 768).Но ваша модель Берта возвращает тензор (None, 768), тогда вы использовали RepeatVector для его дублирования.Я предполагаю, что вы извлекаете [CLS] вывод из Берта.Пожалуйста, извлеките правильный слой из модели Берта.

И причина, по которой трианирование занимает так много времени, заключается лишь в том, что для каждой эпохи вам нужно передавать все свои данные через огромную модель Берта, даже если вы замораживаете большинство слоев.

...