Не удается сохранить модель keras в формате tf - PullRequest
0 голосов
/ 27 мая 2020

Я использую tenorflow-gpu == 2.0.0 Следующая модель не сохраняется:

inputs = Input(shape=(5,1,14))
emb = Input(shape=(512,6,18))

upconv1 = Conv2DTranspose(filters=16, 
                          kernel_size=(6,5), 
                          strides=(6,1),
                          data_format="channels_first")(inputs)

x = MyConcat()([upconv1, emb])

decoder = Model(inputs=[inputs, emb], outputs=x)      
decoder.save("decoder.tf", save_format="tf")

Здесь MyConcat - мой настраиваемый слой:

class MyConcat(keras.layers.Layer):
def __init__(self):
    super(MyConcat, self).__init__()
def call(self, inputs):
    x, emb = inputs
    return Concatenate(axis=1)([x, emb])

def compute_output_shape(self, input_shape):
    shape = (None, input_shape[0][1] + input_shape[1][1],
            input_shape[0][2], input_shape[0][3])
    return shape

Ошибка message:


ValueError                                Traceback (most recent call last)
<ipython-input-10-1bd6aa1b259b> in <module>
     10 
     11 decoder = Model(inputs=[inputs, emb], outputs=x)
---> 12 decoder.save("decoder.tf", save_format="tf")

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options)
    973     """
    974     saving.save_model(self, filepath, overwrite, include_optimizer, save_format,
--> 975                       signatures, options)
    976 
    977   def save_weights(self, filepath, overwrite=True, save_format=None):

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
    113   else:
    114     saved_model_save.save(model, filepath, overwrite, include_optimizer,
--> 115                           signatures, options)
    116 
    117 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options)
     72   # default learning phase placeholder.
     73   with K.learning_phase_scope(0):
---> 74     save_lib.save(model, filepath, signatures, options)
     75 
     76   if not include_optimizer:

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/saved_model/save.py in save(obj, export_dir, signatures, options)
    868   if signatures is None:
    869     signatures = signature_serialization.find_function_to_export(
--> 870         checkpoint_graph_view)
    871 
    872   signatures = signature_serialization.canonicalize_signatures(signatures)

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/saved_model/signature_serialization.py in find_function_to_export(saveable_view)
     62   # If the user did not specify signatures, check the root object for a function
     63   # that can be made into a signature.
---> 64   functions = saveable_view.list_functions(saveable_view.root)
     65   signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
     66   if signature is not None:

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/saved_model/save.py in list_functions(self, obj)
    139     if obj_functions is None:
    140       obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
--> 141           self._serialization_cache)
    142       self._functions[obj] = obj_functions
    143     return obj_functions

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache)
   2420   def _list_functions_for_serialization(self, serialization_cache):
   2421     return (self._trackable_saved_model_saver
-> 2422             .list_functions_for_serialization(serialization_cache))
   2423 
   2424 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache)
     89         `ConcreteFunction`.
     90     """
---> 91     fns = self.functions_to_serialize(serialization_cache)
     92 
     93     # The parent AutoTrackable class saves all user-defined tf.functions, and

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache)
     77   def functions_to_serialize(self, serialization_cache):
     78     return (self._get_serialized_attributes(
---> 79         serialization_cache).functions_to_serialize)
     80 
     81   def _get_serialized_attributes(self, serialization_cache):

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
     92 
     93     object_dict, function_dict = self._get_serialized_attributes_internal(
---> 94         serialization_cache)
     95 
     96     serialized_attr.set_and_validate_objects(object_dict)

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
     51     objects, functions = (
     52         super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
---> 53             serialization_cache))
     54     functions['_default_save_signature'] = default_signature
     55     return objects, functions

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
    101     """Returns dictionary of serialized attributes."""
    102     objects = save_impl.wrap_layer_objects(self.obj, serialization_cache)
--> 103     functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
    104     # Attribute validator requires that the default save signature is added to
    105     # function dict, even if the value is None.

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in wrap_layer_functions(layer, serialization_cache)
    164   call_fn_with_losses = call_collection.add_function(
    165       _wrap_call_and_conditional_losses(layer),
--> 166       '{}_layer_call_and_return_conditional_losses'.format(layer.name))
    167   call_fn = call_collection.add_function(
    168       _extract_outputs_from_fn(layer, call_fn_with_losses),

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in add_function(self, call_fn, name)
    492       # Manually add traces for layers that have keyword arguments and have
    493       # a fully defined input signature.
--> 494       self.add_trace(*self._input_signature)
    495     return fn
    496 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in add_trace(self, *args, **kwargs)
    411             fn.get_concrete_function(*args, **kwargs)
    412 
--> 413         trace_with_training(True)
    414         trace_with_training(False)
    415       else:

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in trace_with_training(value, fn)
    409           utils.set_training_arg(value, self._training_arg_index, args, kwargs)
    410           with K.learning_phase_scope(value):
--> 411             fn.get_concrete_function(*args, **kwargs)
    412 
    413         trace_with_training(True)

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in get_concrete_function(self, *args, **kwargs)
    536     if not self.call_collection.tracing:
    537       self.call_collection.add_trace(*args, **kwargs)
--> 538     return super(LayerCall, self).get_concrete_function(*args, **kwargs)
    539 
    540 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
    774       if self._stateful_fn is None:
    775         initializer_map = object_identity.ObjectIdentityDictionary()
--> 776         self._initialize(args, kwargs, add_initializers_to=initializer_map)
    777         self._initialize_uninitialized_variables(initializer_map)
    778 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    406     self._concrete_stateful_fn = (
    407         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 408             *args, **kwds))
    409 
    410     def invalid_creator_scope(*unused_args, **unused_kwds):

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   1846     if self.input_signature:
   1847       args, kwargs = None, None
-> 1848     graph_function, _, _ = self._maybe_define_function(args, kwargs)
   1849     return graph_function
   1850 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2148         graph_function = self._function_cache.primary.get(cache_key, None)
   2149         if graph_function is None:
-> 2150           graph_function = self._create_graph_function(args, kwargs)
   2151           self._function_cache.primary[cache_key] = graph_function
   2152         return graph_function, args, kwargs

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2039             arg_names=arg_names,
   2040             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2041             capture_by_value=self._capture_by_value),
   2042         self._function_attributes,
   2043         # Tell the ConcreteFunction to clean up its graph once it goes out of

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    913                                           converted_func)
    914 
--> 915       func_outputs = python_func(*func_args, **func_kwargs)
    916 
    917       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    356         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    357         # the function a weak reference to itself to avoid a reference cycle.
--> 358         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    359     weak_wrapped_fn = weakref.ref(wrapped_fn)
    360 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
    513         layer, inputs=inputs, build_graph=False, training=training,
    514         saving=True):
--> 515       ret = method(*args, **kwargs)
    516     _restore_layer_losses(original_losses)
    517     return ret

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
    109         training,
    110         lambda: replace_training_and_call(True),
--> 111         lambda: replace_training_and_call(False))
    112 
    113   # Create arg spec for decorated function. If 'training' is not defined in the

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/utils/tf_utils.py in smart_cond(pred, true_fn, false_fn, name)
     57         pred, true_fn=true_fn, false_fn=false_fn, name=name)
     58   return smart_module.smart_cond(
---> 59       pred, true_fn=true_fn, false_fn=false_fn, name=name)
     60 
     61 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
     52   if pred_value is not None:
     53     if pred_value:
---> 54       return true_fn()
     55     else:
     56       return false_fn()

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/utils.py in <lambda>()
    108     return tf_utils.smart_cond(
    109         training,
--> 110         lambda: replace_training_and_call(True),
    111         lambda: replace_training_and_call(False))
    112 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/utils.py in replace_training_and_call(training)
    104     def replace_training_and_call(training):
    105       set_training_arg(training, training_arg_index, args, kwargs)
--> 106       return wrapped_call(*args, **kwargs)
    107 
    108     return tf_utils.smart_cond(

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in call_and_return_conditional_losses(inputs, *args, **kwargs)
    555   layer_call = _get_layer_call_method(layer)
    556   def call_and_return_conditional_losses(inputs, *args, **kwargs):
--> 557     return layer_call(inputs, *args, **kwargs), layer.get_losses_for(inputs)
    558   return _create_call_fn_decorator(layer, call_and_return_conditional_losses)
    559 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py in call(self, inputs, training, mask)
    706     return self._run_internal_graph(
    707         inputs, training=training, mask=mask,
--> 708         convert_kwargs_to_constants=base_layer_utils.call_context().saving)
    709 
    710   def compute_output_shape(self, input_shape):

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py in _run_internal_graph(self, inputs, training, mask, convert_kwargs_to_constants)
    858 
    859           # Compute outputs.
--> 860           output_tensors = layer(computed_tensors, **kwargs)
    861 
    862           # Update tensor_dict.

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    845                     outputs = base_layer_utils.mark_as_return(outputs, acd)
    846                 else:
--> 847                   outputs = call_fn(cast_inputs, *args, **kwargs)
    848 
    849             except errors.OperatorNotAllowedInGraphError as e:

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/utils.py in return_outputs_and_add_losses(*args, **kwargs)
     55     inputs = args[inputs_arg_index]
     56     args = args[inputs_arg_index + 1:]
---> 57     outputs, losses = fn(inputs, *args, **kwargs)
     58     layer.add_loss(losses, inputs)
     59     return outputs

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in __call__(self, *args, **kwargs)
    530   def __call__(self, *args, **kwargs):
    531     if not self.call_collection.tracing:
--> 532       self.call_collection.add_trace(*args, **kwargs)
    533     return super(LayerCall, self).__call__(*args, **kwargs)
    534 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in add_trace(self, *args, **kwargs)
    414         trace_with_training(False)
    415       else:
--> 416         fn.get_concrete_function(*args, **kwargs)
    417     self.tracing = False
    418 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in get_concrete_function(self, *args, **kwargs)
    536     if not self.call_collection.tracing:
    537       self.call_collection.add_trace(*args, **kwargs)
--> 538     return super(LayerCall, self).get_concrete_function(*args, **kwargs)
    539 
    540 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
    774       if self._stateful_fn is None:
    775         initializer_map = object_identity.ObjectIdentityDictionary()
--> 776         self._initialize(args, kwargs, add_initializers_to=initializer_map)
    777         self._initialize_uninitialized_variables(initializer_map)
    778 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    406     self._concrete_stateful_fn = (
    407         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 408             *args, **kwds))
    409 
    410     def invalid_creator_scope(*unused_args, **unused_kwds):

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   1846     if self.input_signature:
   1847       args, kwargs = None, None
-> 1848     graph_function, _, _ = self._maybe_define_function(args, kwargs)
   1849     return graph_function
   1850 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2148         graph_function = self._function_cache.primary.get(cache_key, None)
   2149         if graph_function is None:
-> 2150           graph_function = self._create_graph_function(args, kwargs)
   2151           self._function_cache.primary[cache_key] = graph_function
   2152         return graph_function, args, kwargs

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2039             arg_names=arg_names,
   2040             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2041             capture_by_value=self._capture_by_value),
   2042         self._function_attributes,
   2043         # Tell the ConcreteFunction to clean up its graph once it goes out of

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    913                                           converted_func)
    914 
--> 915       func_outputs = python_func(*func_args, **func_kwargs)
    916 
    917       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    356         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    357         # the function a weak reference to itself to avoid a reference cycle.
--> 358         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    359     weak_wrapped_fn = weakref.ref(wrapped_fn)
    360 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
    513         layer, inputs=inputs, build_graph=False, training=training,
    514         saving=True):
--> 515       ret = method(*args, **kwargs)
    516     _restore_layer_losses(original_losses)
    517     return ret

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in call_and_return_conditional_losses(inputs, *args, **kwargs)
    555   layer_call = _get_layer_call_method(layer)
    556   def call_and_return_conditional_losses(inputs, *args, **kwargs):
--> 557     return layer_call(inputs, *args, **kwargs), layer.get_losses_for(inputs)
    558   return _create_call_fn_decorator(layer, call_and_return_conditional_losses)
    559 

<ipython-input-5-d4972c5cbebe> in call(self, inputs)
      4     def call(self, inputs):
      5         x, emb = inputs
----> 6         return Concatenate(axis=1)([x, emb])
      7 
      8     def compute_output_shape(self, input_shape):

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    815           # Build layer if applicable (if the `build` method has been
    816           # overridden).
--> 817           self._maybe_build(inputs)
    818           cast_inputs = self._maybe_cast_inputs(inputs)
    819 

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in _maybe_build(self, inputs)
   2139         # operations.
   2140         with tf_utils.maybe_init_scope(self):
-> 2141           self.build(input_shapes)
   2142       # We must set self.built since user defined build functions are not
   2143       # constrained to set self.built.

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/utils/tf_utils.py in wrapper(instance, input_shape)
    304     if input_shape is not None:
    305       input_shape = convert_shapes(input_shape, to_tuples=True)
--> 306     output_shape = fn(instance, input_shape)
    307     # Return shapes from `fn` as TensorShapes.
    308     if output_shape is not None:

~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/layers/merge.py in build(self, input_shape)
    389                        'inputs with matching shapes '
    390                        'except for the concat axis. '
--> 391                        'Got inputs shapes: %s' % (input_shape))
    392 
    393   def _merge_function(self, inputs):

ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [(None, 16, None, None), (None, 512, 6, 18)]

Когда я выхожу только из слоя MyConcat или только транспонирую слой свертки, все работает. Если я просто заменю слой MyConcat на Concatenate (axis = 1), все заработает. В model.summary формы слоя определены правильно. Пожалуйста, не предлагайте переключаться на формат "h5", да, он работает для этого фрагмента кода, но не работает где-либо еще в моем проекте, и рекомендуемое решение - переключиться на формат .tf.

1 Ответ

0 голосов
/ 28 мая 2020

Результат Conv2DTranpose имеет неопределенную форму в keras, это известная проблема. https://github.com/tensorflow/tensorflow/issues/8972

Добавление слоя Reshape после слоя Conv2DTranspose помогает keras определить правильную форму слоя и сохранить модель. (Хотя простой set_shape по какой-то причине не помогает.)

...