Примерка нейронной сети выглядит странно - PullRequest
0 голосов
/ 20 апреля 2020

Я пытаюсь разработать модель, которая будет прогнозировать доли компонента химического рецепта. Вам просто нужно выбрать компоненты и какие свойства вы хотите получить в окончательной смеси. Для Input1 у меня есть некоторые данные о компонентах, таких как вязкость и плотность. В качестве входа 2 у меня есть значения свойств, которые мне нужны. В качестве Output, его (1,4) тензор с 4-мя компонентами, сумма этого равна 1.

. В качестве обучающих данных у меня около 20 000 выборок, это количество не так уж и велико, поэтому я предполагаю, что его модель должен переодеться. В то время как я пытаюсь тренировать свою модель, я решил сначала надеть ее. Тем не менее, это не переоснащение такими параметрами. И максимальная точность для этой модели составляет около 50%, помогите мне улучшить точность модели, пожалуйста.

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
Cov_inp (InputLayer)            (None, 2, 4)         0                                            
__________________________________________________________________________________________________
conv1d_18 (Conv1D)              (None, 2, 128)       2176        Cov_inp[0][0]                    
__________________________________________________________________________________________________
max_pooling1d_18 (MaxPooling1D) (None, 1, 128)       0           conv1d_18[0][0]                  

__________________________________________________________________________________________________

flatten_18 (Flatten)            (None, 128)          0           max_pooling1d_18[0][0]           
__________________________________________________________________________________________________
Res_inp (InputLayer)            (None, 1)            0                                            
__________________________________________________________________________________________________
dense_158 (Dense)               (None, 64)           8256        flatten_18[0][0]                 
__________________________________________________________________________________________________
dense_161 (Dense)               (None, 64)           128         Res_inp[0][0]                    
__________________________________________________________________________________________________
dense_159 (Dense)               (None, 128)          8320        dense_158[0][0]                  
__________________________________________________________________________________________________
dense_162 (Dense)               (None, 128)          8320        dense_161[0][0]                  
__________________________________________________________________________________________________
dense_160 (Dense)               (None, 256)          33024       dense_159[0][0]                  
__________________________________________________________________________________________________
dense_163 (Dense)               (None, 256)          33024       dense_162[0][0]                  
__________________________________________________________________________________________________
dot_10 (Dot)                    (None, 1)            0           dense_160[0][0]                  
                                                                 dense_163[0][0]                  
__________________________________________________________________________________________________
dense_164 (Dense)               (None, 256)          512         dot_10[0][0]                     
__________________________________________________________________________________________________
dense_165 (Dense)               (None, 128)          32896       dense_164[0][0]                  
__________________________________________________________________________________________________
dense_166 (Dense)               (None, 64)           8256        dense_165[0][0]                  
__________________________________________________________________________________________________
dense_167 (Dense)               (None, 32)           2080        dense_166[0][0]                  
__________________________________________________________________________________________________
dense_168 (Dense)               (None, 4)            132         dense_167[0][0]                  
==================================================================================================
Total params: 137,124
Trainable params: 137,124
Non-trainable params: 0

Архитектура и потери, график точности ниже.

The model graph loss plot enter image description here

И мой код в нейронной сети

X_features = Input(shape=(2,4),name='Cov_inp')
KV_data = Input(shape=(1,),name='Res_inp')

# Branch 1
layer1 = Conv1D(filters=128,kernel_size=4,activation='relu',padding='same',input_shape=(2,4))(X_features)
pooling1=MaxPooling1D(pool_size=2)(layer1)
flattern=Flatten()(pooling1)
dense_conv=Dense(64,activation='relu')(flattern)
dense_conv=Dense(128,activation='relu')(dense_conv)
dense_conv=Dense(256,activation='relu')(dense_conv)

# Branch 2
dense1=Dense(64,activation='relu')(KV_data)
dense2=Dense(128,activation='relu')(dense1)
dense4=Dense(256,activation='relu')(dense2)

#Concatination
addition=Dot(axes=1)([dense_conv, dense4])
final_dense = Dense(256,activation='relu')(addition)
final_dense = Dense(128,activation='relu')(final_dense)
final_dense = Dense(64,activation='relu')(final_dense)
final_dense = Dense(32,activation='relu')(final_dense)
out = Dense(4,activation='sigmoid')(final_dense)

from keras import backend as K

def root_mean_squared_error(y_true, y_pred):
        return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1)) 

model = Model(inputs=[X_features, KV_data], outputs=out)
model.compile(loss=root_mean_squared_error, optimizer=adam, metrics=['accuracy'])
model.summary()
history=model.fit({'Cov_inp':dataset_X,'Res_inp':dataset_KV}, dataset_Y,validation_split=0.2, epochs=2000, batch_size=128, verbose=2)
...