Невозможно сохранить подклассную модель TensorFlow 2.1 - __call __ () отсутствует 1 обязательный позиционный аргумент: 'x' - PullRequest
1 голос
/ 03 марта 2020

Я не могу сохранить свою модель с помощью метода tensorflow.keras.Model classes 'save ... Я также пытался использовать tensorflow.saved_model и tensorflow.keras.models.save_model, которые также не работали. Каждый раз, когда я вижу одну и ту же ошибку:

  File ".../keras/saving/saving_utils.py", line 150, in _wrapped_model
    outputs_list = nest.flatten(model(inputs=inputs, training=False))
TypeError: __call__() missing 1 required positional argument: 'x'

Как сохранить свою модель с подклассами, чтобы я мог использовать tensorflow.compat.v1.lite.TFLiteConverter на ней?

Операционная система:

Darwin Alexs-MacBook-Pro.local 18.7.0 Darwin Kernel Version 18.7.0: Thu Jan 23 06:52:12 PST 2020; root:xnu-4903.278.25~1/RELEASE_X86_64 x86_64

Версия Tensorflow: 2.1.0

Код:

from tensorflow import saved_model
import tensorflow.keras as keras

class TrainTest(keras.Model):
    def __init__(self, input_dim=1, hidden_dim=1, **kwargs):
        super(TrainTest, self).__init__()
        self.dense1 = keras.layers.Dense(input_dim, activation=keras.activations.relu)
        self.dense2 = keras.layers.Dense(hidden_dim, activation=keras.activations.relu)
        self.dense3 = keras.layers.Dense(1, activation=keras.activations.linear)

    def __call__(self, x, **kwargs):
        x = self.dense1(x)
        x = self.dense2(x)
        return self.dense3(x)

if __name__ == "__main__":
    (train_x, train_y), (test_x, test_y) = keras.datasets.boston_housing.load_data(test_split=0.1)
    model = TrainTest(input_dim=train_x.shape[1],
            hidden_dim=int(train_x.shape[1] * 1.5))
    model.compile(optimizer=keras.optimizers.Adam(0.001),
            loss=keras.losses.MeanSquaredError(),
            metrics=['mape'])
    model.fit(train_x, train_y,
            batch_size=32,
            epochs=10,
            validation_split=0.1)
    path = 'test_model.pb'
    model.save(path, save_format='tf')
    # saved_model.save(model, path) 
    # keras.models.save_model(model, path)

Stacktrace:

Traceback (most recent call last):
  File "TrainTest.py", line 71, in <module>
    model.save('test', save_format='tf')
  File ".../keras/engine/network.py", line 1008, in save
    signatures, options)
  File ".../keras/saving/save.py", line 115, in save_model
    signatures, options)
  File ".../keras/saving/saved_model/save.py", line 78, in save
    save_lib.save(model, filepath, signatures, options)
  File ".../saved_model/save.py", line 886, in save
    checkpoint_graph_view)
  File ".../saved_model/signature_serialization.py", line 74, in find_function_to_export
    functions = saveable_view.list_functions(saveable_view.root)
  File ".../saved_model/save.py", line 142, in list_functions
    self._serialization_cache)
  File ".../keras/engine/base_layer.py", line 2420, in _list_functions_for_serialization
    .list_functions_for_serialization(serialization_cache))
  File ".../keras/saving/saved_model/base_serialization.py", line 91, in list_functions_for_serialization
    fns = self.functions_to_serialize(serialization_cache)
  File ".../keras/saving/saved_model/layer_serialization.py", line 80, in functions_to_serialize
    serialization_cache).functions_to_serialize)
  File ".../keras/saving/saved_model/layer_serialization.py", line 95, in _get_serialized_attributes
    serialization_cache)
  File ".../keras/saving/saved_model/model_serialization.py", line 47, in _get_serialized_attributes_internal
    default_signature = save_impl.default_save_signature(self.obj)
  File ".../keras/saving/saved_model/save_impl.py", line 212, in default_save_signature
    fn.get_concrete_function()
  File ".../eager/def_function.py", line 909, in get_concrete_function
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File ".../eager/def_function.py", line 497, in _initialize
    *args, **kwds))
  File ".../eager/function.py", line 2389, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File ".../eager/function.py", line 2703, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File ".../eager/function.py", line 2593, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File ".../framework/func_graph.py", line 978, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File ".../eager/def_function.py", line 439, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File ".../keras/saving/saving_utils.py", line 150, in _wrapped_model
    outputs_list = nest.flatten(model(inputs=inputs, training=False))
TypeError: __call__() missing 1 required positional argument: 'x'

1 Ответ

1 голос
/ 03 марта 2020

Я понял это. Класс моей модели не был определен должным образом ... функция, вызываемая при выводе:

def call(self, x):

, а не

def __call__(self, x):
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...