Спасибо @ddoGas за указание на метод model.get_weights()
, который возвращает список весов, которые затем можно сериализовать. Просто некоторый контекст для того, почему я не сохраняю модель обычным способом: мы работаем с классами-оболочками модели, которые связывают модель и пользовательское поведение. Например, до того, как произойдет предсказание, требуется специальная проверка:
class CNN:
...
def predict():
self.do_special_validation()
self.model.predict()
Следовательно, мы сериализуем класс CNN
, а не только базовую модель. Это решение для засолки всего объекта. (pickle(CNN())
терпит неудачу, в противном случае мы бы просто использовали это)
import pickle
def serialize(cnn):
return pickle.dumps({
"weights": cnn.model.get_weights(),
"cnnclass": cnn.__class__
})
def deserialize(cnn_bytes):
loaded = pickle.loads(cnn_bytes)
weights, cnnclass = loaded['weights'], loaded['cnnclass']
cnninstance = cnnclass()
cnninstance.model.set_weights(weights)
return cnninstance
Работает хорошо, спасибо!
PS note с использованием cnn.__class__
, потому что не хотим обязательно связывать это с класс CNN
напрямую, но в целом он работает для любого класса, имеющего атрибут cnn.model
.