Я реализовал текстовую мультиклассовую классификацию, используя RNN + CNN
Краткое описание модели:
def get_model():
input = tf.keras.layers.Input(shape=(max_len,))
embedding = tf.keras.layers.Embedding(vocab_size, embed_size, weights=[embedding_matrix], trainable=False)(input)
layer = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(128, return_sequences=True, dropout=0.1,
recurrent_dropout=0.1))(embedding)
layer = tf.keras.layers.Conv1D(64, kernel_size=3, padding="valid", kernel_initializer="glorot_uniform")(layer)
avg_pool = tf.keras.layers.GlobalAveragePooling1D()(layer)
max_pool = tf.keras.layers.GlobalMaxPooling1D()(layer)
layer = tf.keras.layers.concatenate([avg_pool, max_pool])
output = tf.keras.layers.Dense(len(y.value_counts()), activation="sigmoid")(layer)
model = tf.keras.Model(input, output)
model.summary()
return model
model = get_model()
model.compile(loss='sparse_categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(lr=0.001), metrics=['accuracy'])
model.fit(x_train,y_train,validation_data=(x_test,y_test),
epochs = 1, verbose = 2)
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_16 (InputLayer) [(None, 150)] 0
__________________________________________________________________________________________________
embedding_15 (Embedding) (None, 150, 100) 88400 input_16[0][0]
__________________________________________________________________________________________________
bidirectional_12 (Bidirectional (None, 150, 64) 25536 embedding_15[0][0]
__________________________________________________________________________________________________
conv1d_12 (Conv1D) (None, 148, 32) 6176 bidirectional_12[0][0]
__________________________________________________________________________________________________
global_average_pooling1d_12 (Gl (None, 32) 0 conv1d_12[0][0]
__________________________________________________________________________________________________
global_max_pooling1d_12 (Global (None, 32) 0 conv1d_12[0][0]
__________________________________________________________________________________________________
concatenate_12 (Concatenate) (None, 64) 0 global_average_pooling1d_12[0][0]
global_max_pooling1d_12[0][0]
__________________________________________________________________________________________________
dense_11 (Dense) (None, 102) 6630 concatenate_12[0][0]
==================================================================================================
Total params: 126,742
Trainable params: 38,342
Non-trainable params: 88,400
Проблема, с которой я столкнулся, заключается в моделировании другого результата прогнозированиямежду 1 и 2 эпохами.
Я выполнил прогноз модели после первой эпохи.Использование сигмоида в последнем слое.Итак, независимый прогноз для каждого класса.
Примечание: у меня небольшой набор данных.
После эпохи 1 - прогноз:
[0.821476 0.178482 0.082908 0.070871 0.244470 0.031154 0.035466 0.869020
0.413655 0.768583 0.281448 0.188352 0.417780 0.468368 0.535279 0.629149
0.781784 0.414644 0.218737 0.442238 0.682343 0.358461 0.450273 0.334286
0.577692 0.215712 0.169237 0.938595 0.180421 0.051505 0.440111 0.387701
0.257397 0.205229 0.941195 0.019577 0.138571 0.701121 0.568172 0.152105
0.741303 0.169439 0.035995 0.306321 0.382447 0.268078 0.687641 0.350583
0.524925 0.945273 0.714135 0.097993 0.102559 0.431982 0.803985 0.231302
0.246235 0.366514 0.566957 0.411760 0.316942 0.358484 0.102790 0.206971
0.312865 0.627695 0.293425 0.096269 0.183038 0.310816 0.106294 0.763296
0.253969 0.219500 0.601052 0.041123 0.257971 0.651815 0.211335 0.488649
0.414540 0.964665 0.758828 0.552555 0.589932 0.338783 0.445288 0.794278
0.835401 0.420212 0.514841 0.056917 0.389850 0.232653 0.209908 0.060420
0.390591 0.324862 0.881604 0.269407 0.196394 0.105344]
Это мой ожидаемый результат.
Я продолжил обучение до 128 эпох, точность модели достигла 97. Затем я выполнил прогноз.Получили следующий вывод.
[0.000258 0.000269 0.000021 0.000002 0.000009 0.000007 0.000023 0.000053
0.001453 0.000074 0.000039 0.000060 0.000050 0.000009 0.000628 0.000155
0.001590 0.000133 0.000078 0.000083 0.000039 0.000106 0.000632 0.000037
0.000021 0.000903 0.000020 0.001508 0.000322 0.000001 0.000003 0.000063
0.000002 0.000009 0.000095 0.000130 0.000085 0.000185 0.000062 0.000014
0.000113 0.000009 0.000001 0.000006 0.000001 0.000021 0.000043 0.000003
0.000273 0.026851 0.002266 0.000087 0.000055 0.000084 0.000006 0.000001
0.000119 0.000007 0.014515 0.001661 0.000006 0.001226 0.002544 0.000142
0.000108 0.000063 0.000173 0.000050 0.000012 0.000078 0.000012 0.000016
0.000028 0.000024 0.000240 0.000128 0.000004 0.000016 0.000008 0.000048
0.000045 0.000511 0.000209 0.000076 0.000031 0.000031 0.000330 0.000001
0.000090 0.000128 0.000007 0.000024 0.000032 0.000077 0.000026 0.000008
0.000379 0.000080 0.004676 0.000004 0.000351 0.000041]
Десятичные точки слишком малы.Я не могу пороговые значения с этим более низким десятичным числом.Что здесь пошло не так?