Лучшие модели весов в нейронной сети в случае ранней остановки - PullRequest
0 голосов
/ 03 мая 2020

Я тренирую модель со следующим кодом

model=Sequential()
model.add(Dense(100, activation='relu',input_shape=(n_cols,)))
model.add(Dense(100, activation='relu'))
model.add(Dense(2,activation='softmax'))
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
early_stopping_monitor = EarlyStopping(patience=3)
model.fit(X_train_np,target,validation_split=0.3, epochs=100, callbacks=[early_stopping_monitor])

Это разработано для прекращения обучения, если параметр val_loss: не улучшается после 3 эпох. Результат показан ниже. У меня вопрос, остановится ли модель с весами эпохи 8 или 7. Потому что производительность ухудшилась в 8 эпохе, поэтому она остановилась. Но модель вышла вперед на 1 эпоху с плохими показателями, так как более ранняя (7 эпоха) была лучше. Нужно ли сейчас переучивать модель с 7 эпохами?

Train on 623 samples, validate on 268 samples
Epoch 1/100
623/623 [==============================] - 1s 1ms/step - loss: 4.0365 - accuracy: 0.5923 - val_loss: 1.2208 - val_accuracy: 0.6231
Epoch 2/100
623/623 [==============================] - 0s 114us/step - loss: 1.4412 - accuracy: 0.6356 - val_loss: 0.7193 - val_accuracy: 0.7015
Epoch 3/100
623/623 [==============================] - 0s 103us/step - loss: 1.4335 - accuracy: 0.6260 - val_loss: 1.3778 - val_accuracy: 0.7201
Epoch 4/100
623/623 [==============================] - 0s 106us/step - loss: 3.5732 - accuracy: 0.6324 - val_loss: 2.7310 - val_accuracy: 0.6194
Epoch 5/100
623/623 [==============================] - 0s 111us/step - loss: 1.3116 - accuracy: 0.6372 - val_loss: 0.5952 - val_accuracy: 0.7351
Epoch 6/100
623/623 [==============================] - 0s 98us/step - loss: 0.9357 - accuracy: 0.6645 - val_loss: 0.8047 - val_accuracy: 0.6828
Epoch 7/100
623/623 [==============================] - 0s 105us/step - loss: 0.7671 - accuracy: 0.6934 - val_loss: 0.9918 - val_accuracy: 0.6679
Epoch 8/100
623/623 [==============================] - 0s 126us/step - loss: 2.2968 - accuracy: 0.6629 - val_loss: 1.7789 - val_accuracy: 0.7425

Ответы [ 2 ]

0 голосов
/ 03 мая 2020

Весь код, который я поместил, находится в TensorFlow 2.0

  • Путь к файлу: Это строка, которая может иметь параметры форматирования, такие как эпоха число. Например, ниже приведен общий путь к файлу (вес. {Epoch: 02d} - {val_loss: .2f} .hdf5)
  • monitor: (обычно это 'val_loss'or' val_accuracy ')
  • режим: Должно ли оно быть минимизировать или максимизировать значение монитора (как правило,' min 'или' max ')
  • save_best_only: Если для этого параметра установлено значение true, то модель будет сохранена только для текущей эпохи, если ее значения метри c лучше, чем те, что были раньше. Однако, если для save_best_only установлено значение false, каждая модель будет сохраняться после каждой эпохи (независимо от того, была ли эта модель лучше предыдущих моделей или нет).

Код

model=Sequential()
model.add(Dense(100, activation='relu',input_shape=(n_cols,)))
model.add(Dense(100, activation='relu'))
model.add(Dense(2,activation='softmax'))
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
fname = "weights.{epoch:02d}-{val_loss:.2f}.hdf5"
checkpoint = tf.keras.callbacks.ModelCheckpoint(fname, monitor="val_loss",mode="min", save_best_only=True, verbose=1) 
model.fit(X_train_np,target,validation_split=0.3, epochs=100, callbacks=[checkpoint])
0 голосов
/ 03 мая 2020

Используйте restore_best_weights со значением monitor, установленным на целевое количество. Таким образом, лучшие веса будут восстановлены после обучения автоматически.

early_stopping_monitor = EarlyStopping(patience=3, 
                                       monitor='val_loss',  # assuming it's val_loss
                                       restore_best_weights=True )

Из документов:

restore_best_weights: восстанавливать ли веса моделей из эпохи с лучшим значением из отслеживаемых количество ('val_loss' здесь). Если False, используются веса моделей, полученные на последнем этапе обучения (по умолчанию False).

Ссылка на документацию

...