У меня есть простая сеть Keras, в которой используется настраиваемая функция активации, определенная как лямбда:
from tensorflow.keras.activations import relu
lrelu = lambda x: relu( x, alpha=0.01 )
model = Sequential
model.add(Dense( 10, activation=lrelu, input_dim=12 ))
...
Она компилируется, обучается, тестирует отлично (код опущен), и я могу сохранить это нормально, используя model.save( 'model.h5' )
. Но когда я пытаюсь загрузить его, используя loaded = tf.keras.models.load_model( 'model.h5', custom_objects={'lrelu' : lrelu})
, и, несмотря на определение lrelu
именно так, как показано выше, он жалуется:
ValueError: Unknown activation function:<lambda>
Подождите: не lambda
a python ключевое слово ? Я не собираюсь заново определять python, чтобы загрузить модель - где это закончится? Как мне это преодолеть? Что мне нужно указать в качестве моего custom_objects
?
Согласно руководству TF Keras по сохранению и загрузке с пользовательскими объектами и функциями ...
Пользовательские функции (например, потеря активации или инициализация) не нуждаются в методе get_config. Имени функции достаточно для загрузки, если она зарегистрирована как пользовательский объект.
Мне кажется, это именно то, что я сделал. Может быть, это применимо только к функциям, определенным с помощью def
, а не к лямбда-функциям?