Есть ли способ обучить модель CNN, сохранить веса этой CNN, а затем использовать эти веса, чтобы переобучить эту CNN для других данных поезда? - PullRequest
1 голос
/ 15 марта 2020

Кто-то рассказал мне об этом типе эксперимента. Первым шагом является обучение CNN и сохранение весов, а вторым шагом является использование этих весов для переобучения этого CNN, но на этот раз для добавления дополнительных данных в ваш набор поездов (точная настройка).

I думаю, это что-то вроде трансферного обучения, но с CNN, который вы тренируете. Есть ли способ выбрать веса перед тренировкой CNN, и эти выбранные веса должны быть в вашем файле?

Итак, я до сих пор тренировал свою модель CNN и сохранял веса в файле h5 с помощью код ниже

model.compile(loss='categorical_crossentropy', optimizer=opt,metrics=['accuracy'])
validation_data=(x_testcnn, y_test))
checkpoint_path= 'scratchmodel.best.h5'
save_dir = os.path.join(os.getcwd(), 'weights')
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                             save_weights_only=True,
                                             verbose=1)
 cnnhistory=model.fit(x_traincnn, 
      y_train,
      batch_size=16,           
      epochs=400,
      validation_data=(x_testcnn,y_test),
      callbacks=[cp_callback])

Теперь я хочу переучить тот же CNN, с теми же весами, но на этот раз с добавлением данных в набор поездов. Есть ли способ сделать это? Спасибо за помощь.

1 Ответ

0 голосов
/ 17 марта 2020

Да, вам просто нужно загрузить веса во вновь созданную модель и затем тренироваться с вашими новыми данными.

from tensorflow.python.keras.models import load_model #Tensorflow 2.0

new_model.compile(loss='categorical_crossentropy', optimizer=opt,metrics=['accuracy'])
new_model = load_model(filepath, compile=False) #compile=False allows you to load saved optimizer state

new_model.fit(...) # Fit on new data, leveraging training on old data
...