Сохранить вес модели keras непосредственно в байтах / памяти? - PullRequest
0 голосов
/ 06 марта 2020

Keras позволяет сохранять целые модели или только веса моделей (см. thread ). При сохранении весов они должны быть сохранены в файл, например:

model = keras_model()
model.save_weights('/tmp/model.h5')

Вместо записи в файл, я хотел бы просто сохранить байты в памяти. Что-то вроде

model.dump_weights()

Tensorflow, похоже, не имеет этого, поэтому в качестве обходного пути я пишу на диск, а затем читаю в память:

temp = '/tmp/weights.h5'
model.save_weights(temp)
with open(temp, 'rb') as f:
    weightbytes = f.read()

Любой способ избежать этого круговое движение?

Ответы [ 2 ]

1 голос
/ 07 марта 2020

weights = model.get_weights () получит веса модели. model.set_weights (weights) установит вес модели. Одной из проблем является то, КОГДА вы сохраняете вес модели. Как правило, вы хотите сохранить веса моделей для эпохи, в которой у вас были наименьшие потери при проверке. Обратный вызов Keras ModelCheckpoint сохранит веса с наименьшими потерями при проверке в файл. Я обнаружил, что сохранение в файл неудобно, поэтому я написал небольшой пользовательский обратный вызов, чтобы просто сохранить вес с наименьшими потерями при проверке в переменной класса, а затем после завершения обучения загрузить эти веса в модель, чтобы делать прогнозы. Код показан ниже. Просто добавьте save_best_weights в список обратных вызовов при компиляции модели.

class save_best_weights(tf.keras.callbacks.Callback):
best_weights=model.get_weights()    
def __init__(self):
    super(save_best_weights, self).__init__()
    self.best = np.Inf
def on_epoch_end(self, epoch, logs=None):
    current_loss = logs.get('val_loss')
    accuracy=logs.get('val_accuracy')* 100
    if np.less(current_loss, self.best):
        self.best = current_loss            
        save_best_weights.best_weights=model.get_weights()
        print('\nSaving weights validation loss= {0:6.4f}  validation accuracy= {1:6.3f} %\n'.format(current_loss, accuracy))   

0 голосов
/ 06 марта 2020

Спасибо @ddoGas за указание на метод model.get_weights(), который возвращает список весов, которые затем можно сериализовать. Просто некоторый контекст для того, почему я не сохраняю модель обычным способом: мы работаем с классами-оболочками модели, которые связывают модель и пользовательское поведение. Например, до того, как произойдет предсказание, требуется специальная проверка:

class CNN:
   ...
   def predict():
       self.do_special_validation()
       self.model.predict()

Следовательно, мы сериализуем класс CNN, а не только базовую модель. Это решение для засолки всего объекта. (pickle(CNN()) терпит неудачу, в противном случае мы бы просто использовали это)

import pickle

def serialize(cnn):
    return pickle.dumps({
        "weights": cnn.model.get_weights(),
        "cnnclass": cnn.__class__
    })

def deserialize(cnn_bytes):
    loaded = pickle.loads(cnn_bytes)
    weights, cnnclass = loaded['weights'], loaded['cnnclass']
    cnninstance = cnnclass()
    cnninstance.model.set_weights(weights)
    return cnninstance

Работает хорошо, спасибо!

PS note с использованием cnn.__class__, потому что не хотим обязательно связывать это с класс CNN напрямую, но в целом он работает для любого класса, имеющего атрибут cnn.model.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...