Сохранить историю model.fit для разных эпох - PullRequest
0 голосов
/ 17 марта 2020

Я тренировал свою модель с эпохой = 10. Я снова переучился с эпохой = 3. и снова эпоха 5. так что каждый раз, когда я тренирую модель с эпохой = 10, 3, 5. Я хочу объединить историю всех 3. Для примера, пусть h1 = история model.fit для эпохи = 10, h2 = история model.fit для эпохи = 3, h3 = история model.fit для эпохи = 5.

Теперь в переменной h я хочу h1 + h2 + h3. Вся история будет добавлена ​​к одной переменной, чтобы я мог построить несколько графиков.

код есть,

start_time = time.time()

model.fit(x=X_train, y=y_train, batch_size=32, epochs=10, validation_data=(X_val, y_val), callbacks=[tensorboard, checkpoint])

end_time = time.time()
execution_time = (end_time - start_time)
print(f"Elapsed time: {hms_string(execution_time)}")


start_time = time.time()

model.fit(x=X_train, y=y_train, batch_size=32, epochs=3, validation_data=(X_val, y_val), callbacks=[tensorboard, checkpoint])

end_time = time.time()
execution_time = (end_time - start_time)
print(f"Elapsed time: {hms_string(execution_time)}")

start_time = time.time()

model.fit(x=X_train, y=y_train, batch_size=32, epochs=5, validation_data=(X_val, y_val), callbacks=[tensorboard, checkpoint])

end_time = time.time()
execution_time = (end_time - start_time)
print(f"Elapsed time: {hms_string(execution_time)}")

Ответы [ 2 ]

1 голос
/ 18 марта 2020

Вы можете достичь этой функциональности, создав класс, который подклассы tf.keras.callbacks.Callback и использовать объект этого класса в качестве обратного вызова для model.fit.

import csv
import tensorflow.keras.backend as K
from tensorflow import keras
import os

model_directory='./xyz' # directory to save model history after every epoch 

class StoreModelHistory(keras.callbacks.Callback):

  def on_epoch_end(self,batch,logs=None):
    if ('lr' not in logs.keys()):
      logs.setdefault('lr',0)
      logs['lr'] = K.get_value(self.model.optimizer.lr)

    if not ('model_history.csv' in os.listdir(model_directory)):
      with open(model_directory+'model_history.csv','a') as f:
        y=csv.DictWriter(f,logs.keys())
        y.writeheader()

    with open(model_directory+'model_history.csv','a') as f:
      y=csv.DictWriter(f,logs.keys())
      y.writerow(logs)


model.fit(...,callbacks=[StoreModelHistory()])

Затем вы можете загрузить файл CSV и потеря модели графика, скорость обучения, метрики и т. д. c.

import pandas as pd
import matplotlib.pyplot as plt

EPOCH = 10 # number of epochs the model has trained for

history_dataframe = pd.read_csv(model_directory+'model_history.csv',sep=',')


# Plot training & validation loss values
plt.style.use("ggplot")
plt.plot(range(1,EPOCH+1),
         history_dataframe['loss'])
plt.plot(range(1,EPOCH+1),
         history_dataframe['val_loss'],
         linestyle='--')
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
0 голосов
/ 18 марта 2020

Каждый раз, когда вы вызываете model.fit(), он возвращает keras.callbacks.History объект, атрибут которого history содержит словарь . Ключи к словарю: потеря для обучения, val_loss для потери проверки и любые другие метрики , которые вы могли установить при компиляции.

Следовательно, в вашем случае вы можете сделать:

hist1 = model.fit(...)

# other code lines

hist2 = model.fit(...)

# other code lines

hist3 = model.fit(...)

# create an empty dict to save all three history dicts into
total_history_dict = dict()

for some_key in hist1.keys():
    current_values = [] # to save values from all three hist dicts
    for hist_dict in [hist1.history, hist2.history, hist3.history]:
        current_values += hist_dict[some_key]
    total_history_dict[some_key] = current_values

Теперь total_history_dict - это словарь, ключи которого, как обычно, loss , val_loss , другие метрики и списки значений, показывающие потери / метрики для каждой эпохи. (Длина списка будет суммой количества эпох во всех трех вызовах model.fit )

Теперь вы можете использовать словарь для построения графиков, используя matplotlib или сохранить его в pandas датафрейм и др. c ...

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...