Обратный вызов Tensorflow: как сохранить лучшую модель в памяти, а не на диске - PullRequest
0 голосов
/ 06 мая 2020

Я использую Tensorflow для регрессии, используя следующую функцию

import tensorflow as tf

def ff(*args, **kwargs):
    model = tf.keras.models.Sequential()
    model.add(tf.keras.Input(shape=[inp_train.shape[-1],]))
    for i in range(n_layer):
        model.add(tf.keras.layers.Dense(n_unit, activation=act))
    model.add(tf.keras.layers.Dense(out_train.shape[1]))
    model.compile(optimizer=opt, loss='mae')
    early_stop  = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=100)
    check_point = tf.keras.callbacks.ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True)
    model.fit(inp_train, out_train, epochs=n_epoch, batch_size=s_batch, validation_data=(inp_val, out_val), callbacks=[early_stop, check_point], verbose=0)
    best_model = tf.keras.models.load_model('best_model.h5')
    return model, best_mode

Как видите, я сохраняю лучшую модель с помощью обратного вызова check_point и использую ее позже для прогнозирования. Проблема в том, что таким образом я должен сначала сохранить лучшую модель на диск, а затем загрузить ее с диска. Если я хочу выполнить пару запусков параллельно, поскольку каждый запуск создает файл с тем же именем, это не работает.

Итак, как я могу назначить лучшую модель в переменной, не сохраняя ее на диске?

Ответы [ 3 ]

1 голос
/ 28 августа 2020

Мне пришлось сделать это для себя, и я подумал, что поделюсь:

Обратный вызов:

class SaveBestModel(tf.keras.callbacks.Callback):
    def __init__(self, save_best_metric='val_loss', this_max=False):
        self.save_best_metric = save_best_metric
        self.max = this_max
        if this_max:
            self.best = float('-inf')
        else:
            self.best = float('inf')

    def on_epoch_end(self, epoch, logs=None):
        metric_value = logs[self.save_best_metric]
        if self.max:
            if metric_value > self.best:
                self.best = metric_value
                self.best_model = self.model

        else:
            if metric_value < self.best:
                self.best = metric_value
                self.best_model = self.model

Использование:

save_best_model = SaveBestModel()
model.fit(data, callbacks=[save_best_model]
best_model = save_best_model.best_model
0 голосов
/ 06 мая 2020

Вот базовый c пример создания обратного вызова и сохранения модели во время обратного вызова на внешний list. Это должен быть список (или тип, допускающий изменение с помощью метода). Базовый класс tf.keras.callbacks.Callback расширяется дополнительным аргументом, списком, в методе класса обратного вызова __init___. Этот пример показывает, что это работает. Когда обратный вызов вызывается на training_end, он добавляет текущую модель в список.

import tensorflow as tf
from tensorflow.python.keras.models import Model

# define a custom callback
class MyCustomCallback(tf.keras.callbacks.Callback):

  def __init__(self, external_list):
      self.list_obj = external_list

  def on_train_end(self, logs=None):
      self.list_obj.append(self.model)

# test the idea works
model_save_list = []
my_callback = MyCustomCallback(model_save_list)

model1 = Model()
my_callback.set_model(model1)
my_callback.on_train_end()

print(model_save_list)

Запустите это, и вы увидите, что внутренняя модель добавляется к вашему объекту списка:

[<tensorflow.python.keras.engine.training.Model object at 0x10d230b50>]

Измените свое обучение, добавив новый обратный вызов к обратным вызовам, например:

model.fit(inp_train, out_train, epochs=n_epoch, batch_size=s_batch, validation_data=(inp_val, out_val), callbacks=[early_stop, my_callback], verbose=0)
0 голосов
/ 06 мая 2020

Приведенный ниже код может сохранить вашу модель и загрузить ее позже ...

 import pickle
 filename = 'finalized_model.sav'
 pickle.dump(model, open(filename, 'wb'))    
 loaded_model = pickle.load(open(filename, 'rb'))

Ниже приведен полный код ...

import tensorflow as to
import pickle

def ff(*args, **kwargs):
    model = tf.keras.models.Sequential()
    model.add(tf.keras.Input(shape=[inp_train.shape[-1],]))
    for i in range(n_layer):
        model.add(tf.keras.layers.Dense(n_unit, activation=act))
    model.add(tf.keras.layers.Dense(out_train.shape[1]))
    model.compile(optimizer=opt, loss='mae')
    early_stop  = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=100)
    check_point = tf.keras.callbacks.ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True)
    model.fit(inp_train, out_train, epochs=n_epoch, batch_size=s_batch, validation_data=(inp_val, out_val), callbacks=[early_stop, check_point], verbose=0)
    best_model = tf.keras.models.load_model('best_model.h5')
    filename = 'finalized_model.sav'
    pickle.dump(best_model, open(filename, 'wb'))

    loaded_model = pickle.load(open(filename, 'rb'))
    return loaded_model
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...