Сохранение и загрузка слоев вероятности Tensorflow внутри последовательного модуля Keras - PullRequest
1 голос
/ 07 октября 2019

Я использую вероятностные слои Tensorflow внутри последовательностей Keras. Однако, сохраняя модель как json, а затем загружая ее, выдается исключение. Я использую custom_objects, чтобы иметь возможность загружать пользовательские слои. Вот минималистичный код для воспроизведения ошибки.

import tensorflow_probability as tfp

tfk = tf.keras
tfkl = tf.keras.layers
tfpl = tfp.layers

original_dim = 20
latent_dim = 2
model = tfk.Sequential([
    tfkl.InputLayer(input_shape=original_dim),
    tfkl.Dense(10, activation=tf.nn.leaky_relu),
    tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(latent_dim), activation=None),
    tfpl.MultivariateNormalTriL(latent_dim)
])

model_json = model.to_json()
with open("model.json", "w") as json_file:
    json_file.write(model_json)



loaded_model = tfk.models.model_from_json(
    open('model.json').read(),
    custom_objects={
        'leaky_relu': tf.nn.leaky_relu, 
        'MultivariateNormalTriL': tfpl.MultivariateNormalTriL
    }
)

Я получаю следующее исключение:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-26-bbbeffd9e4be> in <module>
      3     custom_objects={
      4         'leaky_relu': tf.nn.leaky_relu,
----> 5         'MultivariateNormalTriL': tfpl.MultivariateNormalTriL
      6     }
      7 )

//anaconda3/envs/dl-env/lib/python3.7/site-packages/tensorflow/python/keras/saving/model_config.py in model_from_json(json_string, custom_objects)
     94   config = json.loads(json_string)
     95   from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
---> 96   return deserialize(config, custom_objects=custom_objects)

//anaconda3/envs/dl-env/lib/python3.7/site-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
     87       module_objects=globs,
     88       custom_objects=custom_objects,
---> 89       printable_module_name='layer')

//anaconda3/envs/dl-env/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    190             custom_objects=dict(
    191                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 192                 list(custom_objects.items())))
    193       with CustomObjectScope(custom_objects):
    194         return cls.from_config(cls_config)

//anaconda3/envs/dl-env/lib/python3.7/site-packages/tensorflow/python/keras/engine/sequential.py in from_config(cls, config, custom_objects)
    350     for layer_config in layer_configs:
    351       layer = layer_module.deserialize(layer_config,
--> 352                                        custom_objects=custom_objects)
    353       model.add(layer)
    354     if not model.inputs and build_input_shape:

//anaconda3/envs/dl-env/lib/python3.7/site-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
     87       module_objects=globs,
     88       custom_objects=custom_objects,
---> 89       printable_module_name='layer')

//anaconda3/envs/dl-env/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    190             custom_objects=dict(
    191                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 192                 list(custom_objects.items())))
    193       with CustomObjectScope(custom_objects):
    194         return cls.from_config(cls_config)

//anaconda3/envs/dl-env/lib/python3.7/site-packages/tensorflow_probability/python/layers/distribution_layer.py in from_config(cls, config, custom_objects)
    875             config['arguments'][key] = np.array(arg_dict['value'])
    876 
--> 877     return cls(**config)
    878 
    879   @classmethod

TypeError: __init__() missing 1 required positional argument: 'event_size'

1 Ответ

0 голосов
/ 09 октября 2019

Проверьте, работает ли следующий метод загрузки:

loaded_model = tfk.models.model_from_json(
    open('model.json').read(),
    custom_objects={
        'leaky_relu': tf.nn.leaky_relu, 
        'MultivariateNormalTriL': tfpl.MultivariateNormalTriL.params_size(latent_dim)
    }
)
...