Значения прогноза RNN отличаются после первой эпохи - PullRequest
1 голос
/ 29 июня 2019

Я реализовал текстовую мультиклассовую классификацию, используя RNN + CNN

Краткое описание модели:

def get_model():
  input = tf.keras.layers.Input(shape=(max_len,))
  embedding = tf.keras.layers.Embedding(vocab_size, embed_size, weights=[embedding_matrix], trainable=False)(input)
  layer = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(128, return_sequences=True, dropout=0.1,
                                                      recurrent_dropout=0.1))(embedding)

  layer = tf.keras.layers.Conv1D(64, kernel_size=3, padding="valid", kernel_initializer="glorot_uniform")(layer)

  avg_pool = tf.keras.layers.GlobalAveragePooling1D()(layer)
  max_pool = tf.keras.layers.GlobalMaxPooling1D()(layer)

  layer = tf.keras.layers.concatenate([avg_pool, max_pool])

  output = tf.keras.layers.Dense(len(y.value_counts()), activation="sigmoid")(layer)

  model = tf.keras.Model(input, output)

  model.summary()

  return model

model = get_model()
model.compile(loss='sparse_categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(lr=0.001), metrics=['accuracy'])
model.fit(x_train,y_train,validation_data=(x_test,y_test), 
    epochs = 1, verbose = 2)

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_16 (InputLayer)           [(None, 150)]        0                                            
__________________________________________________________________________________________________
embedding_15 (Embedding)        (None, 150, 100)     88400       input_16[0][0]                   
__________________________________________________________________________________________________
bidirectional_12 (Bidirectional (None, 150, 64)      25536       embedding_15[0][0]               
__________________________________________________________________________________________________
conv1d_12 (Conv1D)              (None, 148, 32)      6176        bidirectional_12[0][0]           
__________________________________________________________________________________________________
global_average_pooling1d_12 (Gl (None, 32)           0           conv1d_12[0][0]                  
__________________________________________________________________________________________________
global_max_pooling1d_12 (Global (None, 32)           0           conv1d_12[0][0]                  
__________________________________________________________________________________________________
concatenate_12 (Concatenate)    (None, 64)           0           global_average_pooling1d_12[0][0]
                                                                 global_max_pooling1d_12[0][0]    
__________________________________________________________________________________________________
dense_11 (Dense)                (None, 102)          6630        concatenate_12[0][0]             
==================================================================================================
Total params: 126,742
Trainable params: 38,342
Non-trainable params: 88,400

Проблема, с которой я столкнулся, заключается в моделировании другого результата прогнозированиямежду 1 и 2 эпохами.

Я выполнил прогноз модели после первой эпохи.Использование сигмоида в последнем слое.Итак, независимый прогноз для каждого класса.

Примечание: у меня небольшой набор данных.

После эпохи 1 - прогноз:

[0.821476 0.178482 0.082908 0.070871 0.244470 0.031154 0.035466 0.869020
 0.413655 0.768583 0.281448 0.188352 0.417780 0.468368 0.535279 0.629149
 0.781784 0.414644 0.218737 0.442238 0.682343 0.358461 0.450273 0.334286
 0.577692 0.215712 0.169237 0.938595 0.180421 0.051505 0.440111 0.387701
 0.257397 0.205229 0.941195 0.019577 0.138571 0.701121 0.568172 0.152105
 0.741303 0.169439 0.035995 0.306321 0.382447 0.268078 0.687641 0.350583
 0.524925 0.945273 0.714135 0.097993 0.102559 0.431982 0.803985 0.231302
 0.246235 0.366514 0.566957 0.411760 0.316942 0.358484 0.102790 0.206971
 0.312865 0.627695 0.293425 0.096269 0.183038 0.310816 0.106294 0.763296
 0.253969 0.219500 0.601052 0.041123 0.257971 0.651815 0.211335 0.488649
 0.414540 0.964665 0.758828 0.552555 0.589932 0.338783 0.445288 0.794278
 0.835401 0.420212 0.514841 0.056917 0.389850 0.232653 0.209908 0.060420
 0.390591 0.324862 0.881604 0.269407 0.196394 0.105344]

Это мой ожидаемый результат.

Я продолжил обучение до 128 эпох, точность модели достигла 97. Затем я выполнил прогноз.Получили следующий вывод.

[0.000258 0.000269 0.000021 0.000002 0.000009 0.000007 0.000023 0.000053
 0.001453 0.000074 0.000039 0.000060 0.000050 0.000009 0.000628 0.000155
 0.001590 0.000133 0.000078 0.000083 0.000039 0.000106 0.000632 0.000037
 0.000021 0.000903 0.000020 0.001508 0.000322 0.000001 0.000003 0.000063
 0.000002 0.000009 0.000095 0.000130 0.000085 0.000185 0.000062 0.000014
 0.000113 0.000009 0.000001 0.000006 0.000001 0.000021 0.000043 0.000003
 0.000273 0.026851 0.002266 0.000087 0.000055 0.000084 0.000006 0.000001
 0.000119 0.000007 0.014515 0.001661 0.000006 0.001226 0.002544 0.000142
 0.000108 0.000063 0.000173 0.000050 0.000012 0.000078 0.000012 0.000016
 0.000028 0.000024 0.000240 0.000128 0.000004 0.000016 0.000008 0.000048
 0.000045 0.000511 0.000209 0.000076 0.000031 0.000031 0.000330 0.000001
 0.000090 0.000128 0.000007 0.000024 0.000032 0.000077 0.000026 0.000008
 0.000379 0.000080 0.004676 0.000004 0.000351 0.000041]

Десятичные точки слишком малы.Я не могу пороговые значения с этим более низким десятичным числом.Что здесь пошло не так?

1 Ответ

2 голосов
/ 30 июня 2019

Для классификации по нескольким меткам, которую вы называете «независимыми прогнозами», вы должны использовать потери binary_crossentropy:

model.compile(loss='binary_crossentropy', optimizer=tf.keras.optimizers.Adam(lr=0.001), metrics=['accuracy'])

Поскольку вы использовали разреженную категориальную кроссентропию, вполне вероятно, что ваши метки являются целыми числами, вам нужно закодировать их как двоичные векторы (1 для класса, 1 для некласса), чтобы это на самом деле работало.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...