Я работаю над https://github.com/RaphaelMeudec/deblur-gan/, чтобы улучшить их DeblurGAN. Моя цель - преобразовать H5-модель обучения DeblurGAN в формат TFLite.
В моей модели H5 я определил пользовательский слой с именем ReflectionPadding2D
(код приведен ниже). Для этого преобразования я использую следующие команды Python:
g.save(os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)), include_optimizer=False)
model = tf.keras.models.load_model(
os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)))
converter = tf.lite.TFLiteConverter.from_keras_model_file(model, custom_objects={'ReflectionPadding2D': ReflectionPadding2D})
tflite_model = converter.convert()
open(os.path.join(save_dir, 'full_generator_{}_{}.tflite'.format(epoch_number, current_loss)),
"wb").write(tflite_model)
Как видите, я использую custom_objects
. ReflectionPadding2D
- это просто класс (не объект), импортированный благодаря from deblurgan.layer_utils import ReflectionPadding2D
.
Поскольку моя модель содержит мой пользовательский слой ReflectionPadding2D
, вышеприведенные команды выдают следующую ошибку:
ValueError: Неизвестный слой: ReflectionPadding2D
Код для сохранения моей модели H5
Возможно, мне следует добавить строку для включения чего-либо в сохраненный H5, чтобы разрешить его преобразование в формат TFLite? Вот код, который я использую для сохранения своей модели H5:
g.save(os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)), include_optimizer=False)
Код ReflectionPadding2D
(посмотрите на метод call
)
def spatial_reflection_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
"""
Pad the 2nd and 3rd dimensions of a 4D tensor.
:param x: Input tensor
:param padding: Shape of padding to use
:param data_format: Tensorflow vs Theano convention ('channels_last', 'channels_first')
:return: Tensorflow tensor
"""
assert len(padding) == 2
assert len(padding[0]) == 2
assert len(padding[1]) == 2
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))
if data_format == 'channels_first':
pattern = [[0, 0],
[0, 0],
list(padding[0]),
list(padding[1])]
else:
pattern = [[0, 0],
list(padding[0]), list(padding[1]),
[0, 0]]
return tf.pad(x, pattern, "REFLECT")
class ReflectionPadding2D(Layer):
def __init__(self,
padding=(1, 1),
data_format=None,
**kwargs):
super(ReflectionPadding2D, self).__init__(**kwargs)
self.data_format = conv_utils.normalize_data_format(data_format)
if isinstance(padding, int):
self.padding = ((padding, padding), (padding, padding))
elif hasattr(padding, '__len__'):
if len(padding) != 2:
raise ValueError('`padding` should have two elements. '
'Found: ' + str(padding))
height_padding = conv_utils.normalize_tuple(padding[0], 2,
'1st entry of padding')
width_padding = conv_utils.normalize_tuple(padding[1], 2,
'2nd entry of padding')
self.padding = (height_padding, width_padding)
else:
raise ValueError('`padding` should be either an int, '
'a tuple of 2 ints '
'(symmetric_height_pad, symmetric_width_pad), '
'or a tuple of 2 tuples of 2 ints '
'((top_pad, bottom_pad), (left_pad, right_pad)). '
'Found: ' + str(padding))
self.input_spec = InputSpec(ndim=4)
def compute_output_shape(self, input_shape):
if self.data_format == 'channels_first':
if input_shape[2] is not None:
rows = input_shape[2] + self.padding[0][0] + self.padding[0][1]
else:
rows = None
if input_shape[3] is not None:
cols = input_shape[3] + self.padding[1][0] + self.padding[1][1]
else:
cols = None
return (input_shape[0],
input_shape[1],
rows,
cols)
elif self.data_format == 'channels_last':
if input_shape[1] is not None:
rows = input_shape[1] + self.padding[0][0] + self.padding[0][1]
else:
rows = None
if input_shape[2] is not None:
cols = input_shape[2] + self.padding[1][0] + self.padding[1][1]
else:
cols = None
return (input_shape[0],
rows,
cols,
input_shape[3])
def call(self, inputs):
return spatial_reflection_2d_padding(inputs,
padding=self.padding,
data_format=self.data_format)
def get_config(self):
config = {'padding': self.padding,
'data_format': self.data_format}
base_config = super(ReflectionPadding2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
Последний вопрос
Что я должен изменить, чтобы разрешить преобразование в формат TFLite?