Я использую подклассы моделей для определения своей модели и загрузки весов для моей нейронной сети с помощью специальной функции обучения. Из-за подклассов я не могу использовать model.load
и model.save
.
Мой подкласс для загрузки весов выглядит следующим образом:
def load_W_1(self, filepath, inputs):
x=self.call(inputs)
x=self.load_weights(filepath)
return x
def call(self, inputs, training=True):
x = self.conv_1(inputs)
x = self.leakyrelu_1(x)
x = self.dropout_1(x, training)
x = self.conv_2(x)
x = self.batch_norm_1(x, training)
x = self.leakyrelu_2(x)
x = self.dropout_2(x, training)
x = self.flatten(x)
return self.out(x)
Причина, по которой я называю self.call
в моем load_weights()
функция в противном случае у меня есть проблемы с количеством слоев в load_weights()
.
Чтобы вызвать мою функцию, я затем выполняю следующие строки:
discriminator=discriminator.load_W_1(filepath2, my_image)
Здесь filepath2
- это расположение моего файла .h5, в то время как my_image - это произвольный тензор для соответствия размерам (Вы пытаетесь загрузить файл весов, содержащий 11 слоев, в модель с 10 слоями.)
Но в моей обучающей функции я имею строка для вызова обучения дискриминатора как
disc_loss += train_discriminator(images)
, которая приводит к следующей функции
def train_discriminator(images):
noise = tf.random.normal([batch_size, latent_dim])
with tf.GradientTape() as disc_tape:
generated_imgs = generator(noise, training=True)
Но указывая на:
disc_loss += train_discriminator(images)
Я получаю ошибку, которая говорит:
TypeError: 'NoneType' object is not callable
Я предполагаю, что это потому, что модель заменяется ее весами, когда я вызываю функцию load_weights, но я не знаю, как ее обойти. Мне нужно позволить ему тренироваться с этими обученными h5s.
полный путь трассировки
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-110-cdbe6daa184b> in <module>()
140
141 if __name__ == "__main__":
--> 142 train()
143 #train_on_batch
9 frames
<ipython-input-110-cdbe6daa184b> in train()
104 g_temp,d_temp, temp_all=list(),list(),list()
105
--> 106 disc_loss += train_discriminator(images)
107 d_temp.append(disc_loss)
108 d_all.append(disc_loss)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
455
456 tracing_count = self._get_tracing_count()
--> 457 result = self._call(*args, **kwds)
458 if tracing_count == self._get_tracing_count():
459 self._call_counter.called_without_tracing()
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
501 # This is the first call of __call__, so we have to initialize.
502 initializer_map = object_identity.ObjectIdentityDictionary()
--> 503 self._initialize(args, kwds, add_initializers_to=initializer_map)
504 finally:
505 # At this point we know that the initialization is complete (or less
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
406 self._concrete_stateful_fn = (
407 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
--> 408 *args, **kwds))
409
410 def invalid_creator_scope(*unused_args, **unused_kwds):
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
1846 if self.input_signature:
1847 args, kwargs = None, None
-> 1848 graph_function, _, _ = self._maybe_define_function(args, kwargs)
1849 return graph_function
1850
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
2148 graph_function = self._function_cache.primary.get(cache_key, None)
2149 if graph_function is None:
-> 2150 graph_function = self._create_graph_function(args, kwargs)
2151 self._function_cache.primary[cache_key] = graph_function
2152 return graph_function, args, kwargs
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
2039 arg_names=arg_names,
2040 override_flat_arg_shapes=override_flat_arg_shapes,
-> 2041 capture_by_value=self._capture_by_value),
2042 self._function_attributes,
2043 # Tell the ConcreteFunction to clean up its graph once it goes out of
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
913 converted_func)
914
--> 915 func_outputs = python_func(*func_args, **func_kwargs)
916
917 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
356 # __wrapped__ allows AutoGraph to swap in a converted function. We give
357 # the function a weak reference to itself to avoid a reference cycle.
--> 358 return weak_wrapped_fn().__wrapped__(*args, **kwds)
359 weak_wrapped_fn = weakref.ref(wrapped_fn)
360
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in wrapper(*args, **kwargs)
903 except Exception as e: # pylint:disable=broad-except
904 if hasattr(e, "ag_error_metadata"):
--> 905 raise e.ag_error_metadata.to_exception(e)
906 else:
907 raise
TypeError: in converted code:
<ipython-input-110-cdbe6daa184b>:59 train_discriminator *
generated_output = discriminator(generated_imgs, training=True)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/autograph/impl/api.py:394 converted_call
return py_builtins.overload_of(f)(*args, **kwargs)
TypeError: 'NoneType' object is not callable
Edir: только что изменен discriminator=discriminator.load_W_1(filepath2, my_image)
на discriminator.load_W_1(filepath2, my_image)
Теперь он тренируется, но с начальными значениями, кажется, нет загрузки весов