Я использую tenorflow-gpu == 2.0.0 Следующая модель не сохраняется:
inputs = Input(shape=(5,1,14))
emb = Input(shape=(512,6,18))
upconv1 = Conv2DTranspose(filters=16,
kernel_size=(6,5),
strides=(6,1),
data_format="channels_first")(inputs)
x = MyConcat()([upconv1, emb])
decoder = Model(inputs=[inputs, emb], outputs=x)
decoder.save("decoder.tf", save_format="tf")
Здесь MyConcat - мой настраиваемый слой:
class MyConcat(keras.layers.Layer):
def __init__(self):
super(MyConcat, self).__init__()
def call(self, inputs):
x, emb = inputs
return Concatenate(axis=1)([x, emb])
def compute_output_shape(self, input_shape):
shape = (None, input_shape[0][1] + input_shape[1][1],
input_shape[0][2], input_shape[0][3])
return shape
Ошибка message:
ValueError Traceback (most recent call last)
<ipython-input-10-1bd6aa1b259b> in <module>
10
11 decoder = Model(inputs=[inputs, emb], outputs=x)
---> 12 decoder.save("decoder.tf", save_format="tf")
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options)
973 """
974 saving.save_model(self, filepath, overwrite, include_optimizer, save_format,
--> 975 signatures, options)
976
977 def save_weights(self, filepath, overwrite=True, save_format=None):
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
113 else:
114 saved_model_save.save(model, filepath, overwrite, include_optimizer,
--> 115 signatures, options)
116
117
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options)
72 # default learning phase placeholder.
73 with K.learning_phase_scope(0):
---> 74 save_lib.save(model, filepath, signatures, options)
75
76 if not include_optimizer:
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/saved_model/save.py in save(obj, export_dir, signatures, options)
868 if signatures is None:
869 signatures = signature_serialization.find_function_to_export(
--> 870 checkpoint_graph_view)
871
872 signatures = signature_serialization.canonicalize_signatures(signatures)
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/saved_model/signature_serialization.py in find_function_to_export(saveable_view)
62 # If the user did not specify signatures, check the root object for a function
63 # that can be made into a signature.
---> 64 functions = saveable_view.list_functions(saveable_view.root)
65 signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
66 if signature is not None:
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/saved_model/save.py in list_functions(self, obj)
139 if obj_functions is None:
140 obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access
--> 141 self._serialization_cache)
142 self._functions[obj] = obj_functions
143 return obj_functions
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache)
2420 def _list_functions_for_serialization(self, serialization_cache):
2421 return (self._trackable_saved_model_saver
-> 2422 .list_functions_for_serialization(serialization_cache))
2423
2424
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache)
89 `ConcreteFunction`.
90 """
---> 91 fns = self.functions_to_serialize(serialization_cache)
92
93 # The parent AutoTrackable class saves all user-defined tf.functions, and
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache)
77 def functions_to_serialize(self, serialization_cache):
78 return (self._get_serialized_attributes(
---> 79 serialization_cache).functions_to_serialize)
80
81 def _get_serialized_attributes(self, serialization_cache):
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
92
93 object_dict, function_dict = self._get_serialized_attributes_internal(
---> 94 serialization_cache)
95
96 serialized_attr.set_and_validate_objects(object_dict)
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
51 objects, functions = (
52 super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
---> 53 serialization_cache))
54 functions['_default_save_signature'] = default_signature
55 return objects, functions
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
101 """Returns dictionary of serialized attributes."""
102 objects = save_impl.wrap_layer_objects(self.obj, serialization_cache)
--> 103 functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
104 # Attribute validator requires that the default save signature is added to
105 # function dict, even if the value is None.
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in wrap_layer_functions(layer, serialization_cache)
164 call_fn_with_losses = call_collection.add_function(
165 _wrap_call_and_conditional_losses(layer),
--> 166 '{}_layer_call_and_return_conditional_losses'.format(layer.name))
167 call_fn = call_collection.add_function(
168 _extract_outputs_from_fn(layer, call_fn_with_losses),
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in add_function(self, call_fn, name)
492 # Manually add traces for layers that have keyword arguments and have
493 # a fully defined input signature.
--> 494 self.add_trace(*self._input_signature)
495 return fn
496
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in add_trace(self, *args, **kwargs)
411 fn.get_concrete_function(*args, **kwargs)
412
--> 413 trace_with_training(True)
414 trace_with_training(False)
415 else:
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in trace_with_training(value, fn)
409 utils.set_training_arg(value, self._training_arg_index, args, kwargs)
410 with K.learning_phase_scope(value):
--> 411 fn.get_concrete_function(*args, **kwargs)
412
413 trace_with_training(True)
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in get_concrete_function(self, *args, **kwargs)
536 if not self.call_collection.tracing:
537 self.call_collection.add_trace(*args, **kwargs)
--> 538 return super(LayerCall, self).get_concrete_function(*args, **kwargs)
539
540
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
774 if self._stateful_fn is None:
775 initializer_map = object_identity.ObjectIdentityDictionary()
--> 776 self._initialize(args, kwargs, add_initializers_to=initializer_map)
777 self._initialize_uninitialized_variables(initializer_map)
778
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
406 self._concrete_stateful_fn = (
407 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
--> 408 *args, **kwds))
409
410 def invalid_creator_scope(*unused_args, **unused_kwds):
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
1846 if self.input_signature:
1847 args, kwargs = None, None
-> 1848 graph_function, _, _ = self._maybe_define_function(args, kwargs)
1849 return graph_function
1850
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
2148 graph_function = self._function_cache.primary.get(cache_key, None)
2149 if graph_function is None:
-> 2150 graph_function = self._create_graph_function(args, kwargs)
2151 self._function_cache.primary[cache_key] = graph_function
2152 return graph_function, args, kwargs
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
2039 arg_names=arg_names,
2040 override_flat_arg_shapes=override_flat_arg_shapes,
-> 2041 capture_by_value=self._capture_by_value),
2042 self._function_attributes,
2043 # Tell the ConcreteFunction to clean up its graph once it goes out of
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
913 converted_func)
914
--> 915 func_outputs = python_func(*func_args, **func_kwargs)
916
917 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
356 # __wrapped__ allows AutoGraph to swap in a converted function. We give
357 # the function a weak reference to itself to avoid a reference cycle.
--> 358 return weak_wrapped_fn().__wrapped__(*args, **kwds)
359 weak_wrapped_fn = weakref.ref(wrapped_fn)
360
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
513 layer, inputs=inputs, build_graph=False, training=training,
514 saving=True):
--> 515 ret = method(*args, **kwargs)
516 _restore_layer_losses(original_losses)
517 return ret
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
109 training,
110 lambda: replace_training_and_call(True),
--> 111 lambda: replace_training_and_call(False))
112
113 # Create arg spec for decorated function. If 'training' is not defined in the
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/utils/tf_utils.py in smart_cond(pred, true_fn, false_fn, name)
57 pred, true_fn=true_fn, false_fn=false_fn, name=name)
58 return smart_module.smart_cond(
---> 59 pred, true_fn=true_fn, false_fn=false_fn, name=name)
60
61
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
52 if pred_value is not None:
53 if pred_value:
---> 54 return true_fn()
55 else:
56 return false_fn()
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/utils.py in <lambda>()
108 return tf_utils.smart_cond(
109 training,
--> 110 lambda: replace_training_and_call(True),
111 lambda: replace_training_and_call(False))
112
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/utils.py in replace_training_and_call(training)
104 def replace_training_and_call(training):
105 set_training_arg(training, training_arg_index, args, kwargs)
--> 106 return wrapped_call(*args, **kwargs)
107
108 return tf_utils.smart_cond(
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in call_and_return_conditional_losses(inputs, *args, **kwargs)
555 layer_call = _get_layer_call_method(layer)
556 def call_and_return_conditional_losses(inputs, *args, **kwargs):
--> 557 return layer_call(inputs, *args, **kwargs), layer.get_losses_for(inputs)
558 return _create_call_fn_decorator(layer, call_and_return_conditional_losses)
559
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py in call(self, inputs, training, mask)
706 return self._run_internal_graph(
707 inputs, training=training, mask=mask,
--> 708 convert_kwargs_to_constants=base_layer_utils.call_context().saving)
709
710 def compute_output_shape(self, input_shape):
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py in _run_internal_graph(self, inputs, training, mask, convert_kwargs_to_constants)
858
859 # Compute outputs.
--> 860 output_tensors = layer(computed_tensors, **kwargs)
861
862 # Update tensor_dict.
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
845 outputs = base_layer_utils.mark_as_return(outputs, acd)
846 else:
--> 847 outputs = call_fn(cast_inputs, *args, **kwargs)
848
849 except errors.OperatorNotAllowedInGraphError as e:
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/utils.py in return_outputs_and_add_losses(*args, **kwargs)
55 inputs = args[inputs_arg_index]
56 args = args[inputs_arg_index + 1:]
---> 57 outputs, losses = fn(inputs, *args, **kwargs)
58 layer.add_loss(losses, inputs)
59 return outputs
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in __call__(self, *args, **kwargs)
530 def __call__(self, *args, **kwargs):
531 if not self.call_collection.tracing:
--> 532 self.call_collection.add_trace(*args, **kwargs)
533 return super(LayerCall, self).__call__(*args, **kwargs)
534
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in add_trace(self, *args, **kwargs)
414 trace_with_training(False)
415 else:
--> 416 fn.get_concrete_function(*args, **kwargs)
417 self.tracing = False
418
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in get_concrete_function(self, *args, **kwargs)
536 if not self.call_collection.tracing:
537 self.call_collection.add_trace(*args, **kwargs)
--> 538 return super(LayerCall, self).get_concrete_function(*args, **kwargs)
539
540
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
774 if self._stateful_fn is None:
775 initializer_map = object_identity.ObjectIdentityDictionary()
--> 776 self._initialize(args, kwargs, add_initializers_to=initializer_map)
777 self._initialize_uninitialized_variables(initializer_map)
778
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
406 self._concrete_stateful_fn = (
407 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
--> 408 *args, **kwds))
409
410 def invalid_creator_scope(*unused_args, **unused_kwds):
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
1846 if self.input_signature:
1847 args, kwargs = None, None
-> 1848 graph_function, _, _ = self._maybe_define_function(args, kwargs)
1849 return graph_function
1850
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
2148 graph_function = self._function_cache.primary.get(cache_key, None)
2149 if graph_function is None:
-> 2150 graph_function = self._create_graph_function(args, kwargs)
2151 self._function_cache.primary[cache_key] = graph_function
2152 return graph_function, args, kwargs
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
2039 arg_names=arg_names,
2040 override_flat_arg_shapes=override_flat_arg_shapes,
-> 2041 capture_by_value=self._capture_by_value),
2042 self._function_attributes,
2043 # Tell the ConcreteFunction to clean up its graph once it goes out of
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
913 converted_func)
914
--> 915 func_outputs = python_func(*func_args, **func_kwargs)
916
917 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
356 # __wrapped__ allows AutoGraph to swap in a converted function. We give
357 # the function a weak reference to itself to avoid a reference cycle.
--> 358 return weak_wrapped_fn().__wrapped__(*args, **kwds)
359 weak_wrapped_fn = weakref.ref(wrapped_fn)
360
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
513 layer, inputs=inputs, build_graph=False, training=training,
514 saving=True):
--> 515 ret = method(*args, **kwargs)
516 _restore_layer_losses(original_losses)
517 return ret
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in call_and_return_conditional_losses(inputs, *args, **kwargs)
555 layer_call = _get_layer_call_method(layer)
556 def call_and_return_conditional_losses(inputs, *args, **kwargs):
--> 557 return layer_call(inputs, *args, **kwargs), layer.get_losses_for(inputs)
558 return _create_call_fn_decorator(layer, call_and_return_conditional_losses)
559
<ipython-input-5-d4972c5cbebe> in call(self, inputs)
4 def call(self, inputs):
5 x, emb = inputs
----> 6 return Concatenate(axis=1)([x, emb])
7
8 def compute_output_shape(self, input_shape):
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
815 # Build layer if applicable (if the `build` method has been
816 # overridden).
--> 817 self._maybe_build(inputs)
818 cast_inputs = self._maybe_cast_inputs(inputs)
819
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in _maybe_build(self, inputs)
2139 # operations.
2140 with tf_utils.maybe_init_scope(self):
-> 2141 self.build(input_shapes)
2142 # We must set self.built since user defined build functions are not
2143 # constrained to set self.built.
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/utils/tf_utils.py in wrapper(instance, input_shape)
304 if input_shape is not None:
305 input_shape = convert_shapes(input_shape, to_tuples=True)
--> 306 output_shape = fn(instance, input_shape)
307 # Return shapes from `fn` as TensorShapes.
308 if output_shape is not None:
~/.pyenv/versions/3.7.6/lib/python3.7/site-packages/tensorflow_core/python/keras/layers/merge.py in build(self, input_shape)
389 'inputs with matching shapes '
390 'except for the concat axis. '
--> 391 'Got inputs shapes: %s' % (input_shape))
392
393 def _merge_function(self, inputs):
ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [(None, 16, None, None), (None, 512, 6, 18)]
Когда я выхожу только из слоя MyConcat или только транспонирую слой свертки, все работает. Если я просто заменю слой MyConcat на Concatenate (axis = 1), все заработает. В model.summary формы слоя определены правильно. Пожалуйста, не предлагайте переключаться на формат "h5", да, он работает для этого фрагмента кода, но не работает где-либо еще в моем проекте, и рекомендуемое решение - переключиться на формат .tf.