Пользовательский объект TFLite: ValueError: Неизвестный слой: ReflectionPadding2D - PullRequest
0 голосов
/ 29 октября 2019

Я работаю над https://github.com/RaphaelMeudec/deblur-gan/, чтобы улучшить их DeblurGAN. Моя цель - преобразовать H5-модель обучения DeblurGAN в формат TFLite.

В моей модели H5 я определил пользовательский слой с именем ReflectionPadding2D (код приведен ниже). Для этого преобразования я использую следующие команды Python:

g.save(os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)), include_optimizer=False)
model = tf.keras.models.load_model(
    os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)))
converter = tf.lite.TFLiteConverter.from_keras_model_file(model, custom_objects={'ReflectionPadding2D': ReflectionPadding2D})
tflite_model = converter.convert()
open(os.path.join(save_dir, 'full_generator_{}_{}.tflite'.format(epoch_number, current_loss)),
     "wb").write(tflite_model)

Как видите, я использую custom_objects. ReflectionPadding2D - это просто класс (не объект), импортированный благодаря from deblurgan.layer_utils import ReflectionPadding2D.

Поскольку моя модель содержит мой пользовательский слой ReflectionPadding2D, вышеприведенные команды выдают следующую ошибку:

ValueError: Неизвестный слой: ReflectionPadding2D

Код для сохранения моей модели H5

Возможно, мне следует добавить строку для включения чего-либо в сохраненный H5, чтобы разрешить его преобразование в формат TFLite? Вот код, который я использую для сохранения своей модели H5:

g.save(os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)), include_optimizer=False)

Код ReflectionPadding2D (посмотрите на метод call)

def spatial_reflection_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
    """
    Pad the 2nd and 3rd dimensions of a 4D tensor.

    :param x: Input tensor
    :param padding: Shape of padding to use
    :param data_format: Tensorflow vs Theano convention ('channels_last', 'channels_first')
    :return: Tensorflow tensor
    """
    assert len(padding) == 2
    assert len(padding[0]) == 2
    assert len(padding[1]) == 2
    if data_format is None:
        data_format = image_data_format()
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('Unknown data_format ' + str(data_format))

    if data_format == 'channels_first':
        pattern = [[0, 0],
                   [0, 0],
                   list(padding[0]),
                   list(padding[1])]
    else:
        pattern = [[0, 0],
                   list(padding[0]), list(padding[1]),
                   [0, 0]]
    return tf.pad(x, pattern, "REFLECT")


class ReflectionPadding2D(Layer):

    def __init__(self,
                 padding=(1, 1),
                 data_format=None,
                 **kwargs):
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        if isinstance(padding, int):
            self.padding = ((padding, padding), (padding, padding))
        elif hasattr(padding, '__len__'):
            if len(padding) != 2:
                raise ValueError('`padding` should have two elements. '
                                 'Found: ' + str(padding))
            height_padding = conv_utils.normalize_tuple(padding[0], 2,
                                                        '1st entry of padding')
            width_padding = conv_utils.normalize_tuple(padding[1], 2,
                                                       '2nd entry of padding')
            self.padding = (height_padding, width_padding)
        else:
            raise ValueError('`padding` should be either an int, '
                             'a tuple of 2 ints '
                             '(symmetric_height_pad, symmetric_width_pad), '
                             'or a tuple of 2 tuples of 2 ints '
                             '((top_pad, bottom_pad), (left_pad, right_pad)). '
                             'Found: ' + str(padding))
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_first':
            if input_shape[2] is not None:
                rows = input_shape[2] + self.padding[0][0] + self.padding[0][1]
            else:
                rows = None
            if input_shape[3] is not None:
                cols = input_shape[3] + self.padding[1][0] + self.padding[1][1]
            else:
                cols = None
            return (input_shape[0],
                    input_shape[1],
                    rows,
                    cols)
        elif self.data_format == 'channels_last':
            if input_shape[1] is not None:
                rows = input_shape[1] + self.padding[0][0] + self.padding[0][1]
            else:
                rows = None
            if input_shape[2] is not None:
                cols = input_shape[2] + self.padding[1][0] + self.padding[1][1]
            else:
                cols = None
            return (input_shape[0],
                    rows,
                    cols,
                    input_shape[3])

    def call(self, inputs):
        return spatial_reflection_2d_padding(inputs,
                                             padding=self.padding,
                                             data_format=self.data_format)

    def get_config(self):
        config = {'padding': self.padding,
                  'data_format': self.data_format}
        base_config = super(ReflectionPadding2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

Последний вопрос

Что я должен изменить, чтобы разрешить преобразование в формат TFLite?

1 Ответ

0 голосов
/ 29 октября 2019

Этот код работает:

g.save(os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)), include_optimizer=False)

model = tf.keras.models.load_model(
    os.path.join(save_dir, 'full_generator_{}_{}.h5'.format(epoch_number, current_loss)), custom_objects={'ReflectionPadding2D': ReflectionPadding2D})
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open(os.path.join(save_dir, 'full_generator_{}_{}.tflite'.format(epoch_number, current_loss)),
     "wb").write(tflite_model)
...