Когда я запускаю следующий код, следующая ошибка прерывает процесс обучения.
ValueError: Нет данных для "embedding_15_input". Нужны данные для каждого ключа: ['embedding_15_input']
Я хотел бы отметить, что я хочу построить сеть lstm с выходом multi_lable (11 меток).
Вот функция для создания структуры модели:
def lstm_twiter (n_input, n_out, input_dim, units_activation = 'tanh', batch_size = 20):
model = Sequential()
embedding_size_out = min(50, input_dim/2)
model.add(Embedding( input_length = n_input, output_dim = embedding_size_out\
, input_dim = input_dim, mask_zero = True))
model.add(Bidirectional(LSTM(100,activation=units_activation)))
model.add(Dropout(0.5))
model.add(Dense(n_out,activation="sigmoid"))
callsback = EarlyStopping(patience =2 )
dict_1={'callbacks':[callsback],'batch_size':batch_size}
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])
return(model, dict_1)
Вот как я это называю:
matrix_input_train, matrix_output_train, matrix_input_dev,\
matrix_output_dev, matrix_input_test, matrix_output_test,size_of_vocab= \
preprocessing (txt_file_train, txt_file_dv)
n_input = matrix_input_train.shape[1]
input_dim = size_of_vocab
n_out = matrix_output_train.shape[1]
model, dict_1=lstm_twiter(n_input, n_out, input_dim,units_activation = 'tanh'\
, batch_size =20 )
dict_1.update(x=matrix_input_train,y=matrix_output_train,epochs=10, \
validation_data=(matrix_input_dev, matrix_output_dev))
model.fit(dict_1)
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding_17 (Embedding) (None, 56, 50) 1103150
_________________________________________________________________
bidirectional_15 (Bidirectio (None, 200) 120800
_________________________________________________________________
dropout_15 (Dropout) (None, 200) 0
_________________________________________________________________
dense_14 (Dense) (None, 11) 2211
=================================================================
Total params: 1,226,161
Trainable params: 1,226,161
Non-trainable params: 0
______________________________________________________