Я ни в коем случае не эксперт, но вот что я бы предложил в порядке убывания наиболее важных:
1) Добавьте больше данных, если это возможно. Больше данных - это всегда хорошо, и это помогает повысить надежность в соответствии с прогнозами вашей сети.
2) Добавьте выпадающие слои, чтобы избежать переподгонки
3) Имейте повозкус инициализаторами ядра и смещения
4) [Наиболее актуальный ответ на ваш вопрос] Сохраните тренировочные веса вашей модели и загрузите их в новую модель перед тренировкой.
5) Измените тип используемой вами архитектуры модели. Затем возьмите с собой повозку с числами эпох, разделениями проверки, формулами оценки потерь и т. Д.
Надеюсь, это поможет!
РЕДАКТИРОВАТЬ: дополнительная информация о номере 4
Таким образом, вы можете сохранять и загружать вес модели во время или после тренировки модели. См. Здесь для более подробной информации о сохранении.
В общих чертах, давайте рассмотрим основы. Я предполагаю, что вы проходите через керас, но то же самое относится к tf:
Сохранение модели после тренировки
Просто позвоните:
model_json = model.to_json()
with open("{Your_Model}.json", "w") as json_file:
json_file.write(model_json)
# serialize weights to HDF5
model.save_weights("{Your_Model}.h5")
print("Saved model to disk")
Загрузка модели
Вы можете загрузить структуру модели из json следующим образом:
from keras.models import model_from_json
json_file = open('{Your_Model.json}', 'r')
loaded_model_json = json_file.read()
json_file.close()
model = model_from_json(loaded_model_json)
И загрузить веса, если хотите:
model.load_weights('{Your_Weights}.h5', by_name=True)
Затем скомпилируйте модель, и вы готовы переучить / предсказать. by_name
для меня было важно перезагрузить весы обратно в ту же модель архитектуры;пропуск этого значения может привести к ошибке.
Проверка модели во время тренировки
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath={checkpoint_path},
save_weights_only=True,
verbose=1)
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=10,
validation_data=(test_images,test_labels),
callbacks=[cp_callback]) # Pass callback to training