load_weights возвращает не вызываемое значение - PullRequest
0 голосов
/ 27 марта 2020

Я использую подклассы моделей для определения своей модели и загрузки весов для моей нейронной сети с помощью специальной функции обучения. Из-за подклассов я не могу использовать model.load и model.save.

Мой подкласс для загрузки весов выглядит следующим образом:

def load_W_1(self, filepath, inputs):
    x=self.call(inputs)
    x=self.load_weights(filepath)
    return x


def call(self, inputs, training=True):

    x = self.conv_1(inputs)
    x = self.leakyrelu_1(x)
    x = self.dropout_1(x, training)

    x = self.conv_2(x)
    x = self.batch_norm_1(x, training)
    x = self.leakyrelu_2(x)
    x = self.dropout_2(x, training)

    x = self.flatten(x)

    return self.out(x)

Причина, по которой я называю self.call в моем load_weights() функция в противном случае у меня есть проблемы с количеством слоев в load_weights().

Чтобы вызвать мою функцию, я затем выполняю следующие строки:

discriminator=discriminator.load_W_1(filepath2, my_image)

Здесь filepath2 - это расположение моего файла .h5, в то время как my_image - это произвольный тензор для соответствия размерам (Вы пытаетесь загрузить файл весов, содержащий 11 слоев, в модель с 10 слоями.)

Но в моей обучающей функции я имею строка для вызова обучения дискриминатора как

disc_loss += train_discriminator(images)

, которая приводит к следующей функции

def train_discriminator(images):
        noise = tf.random.normal([batch_size, latent_dim])

        with tf.GradientTape() as disc_tape:
            generated_imgs = generator(noise, training=True)

Но указывая на:

disc_loss += train_discriminator(images)

Я получаю ошибку, которая говорит:

TypeError: 'NoneType' object is not callable

Я предполагаю, что это потому, что модель заменяется ее весами, когда я вызываю функцию load_weights, но я не знаю, как ее обойти. Мне нужно позволить ему тренироваться с этими обученными h5s.

полный путь трассировки


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-110-cdbe6daa184b> in <module>()
    140 
    141 if __name__ == "__main__":
--> 142     train()
    143     #train_on_batch

9 frames
<ipython-input-110-cdbe6daa184b> in train()
    104             g_temp,d_temp, temp_all=list(),list(),list()
    105 
--> 106             disc_loss += train_discriminator(images)
    107             d_temp.append(disc_loss)
    108             d_all.append(disc_loss)

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
    455 
    456     tracing_count = self._get_tracing_count()
--> 457     result = self._call(*args, **kwds)
    458     if tracing_count == self._get_tracing_count():
    459       self._call_counter.called_without_tracing()

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
    501       # This is the first call of __call__, so we have to initialize.
    502       initializer_map = object_identity.ObjectIdentityDictionary()
--> 503       self._initialize(args, kwds, add_initializers_to=initializer_map)
    504     finally:
    505       # At this point we know that the initialization is complete (or less

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    406     self._concrete_stateful_fn = (
    407         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 408             *args, **kwds))
    409 
    410     def invalid_creator_scope(*unused_args, **unused_kwds):

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   1846     if self.input_signature:
   1847       args, kwargs = None, None
-> 1848     graph_function, _, _ = self._maybe_define_function(args, kwargs)
   1849     return graph_function
   1850 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2148         graph_function = self._function_cache.primary.get(cache_key, None)
   2149         if graph_function is None:
-> 2150           graph_function = self._create_graph_function(args, kwargs)
   2151           self._function_cache.primary[cache_key] = graph_function
   2152         return graph_function, args, kwargs

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2039             arg_names=arg_names,
   2040             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2041             capture_by_value=self._capture_by_value),
   2042         self._function_attributes,
   2043         # Tell the ConcreteFunction to clean up its graph once it goes out of

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    913                                           converted_func)
    914 
--> 915       func_outputs = python_func(*func_args, **func_kwargs)
    916 
    917       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    356         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    357         # the function a weak reference to itself to avoid a reference cycle.
--> 358         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    359     weak_wrapped_fn = weakref.ref(wrapped_fn)
    360 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in wrapper(*args, **kwargs)
    903           except Exception as e:  # pylint:disable=broad-except
    904             if hasattr(e, "ag_error_metadata"):
--> 905               raise e.ag_error_metadata.to_exception(e)
    906             else:
    907               raise

TypeError: in converted code:

    <ipython-input-110-cdbe6daa184b>:59 train_discriminator  *
        generated_output = discriminator(generated_imgs, training=True)
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/autograph/impl/api.py:394 converted_call
        return py_builtins.overload_of(f)(*args, **kwargs)

    TypeError: 'NoneType' object is not callable

Edir: только что изменен discriminator=discriminator.load_W_1(filepath2, my_image)

на discriminator.load_W_1(filepath2, my_image)

Теперь он тренируется, но с начальными значениями, кажется, нет загрузки весов

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...