Как сохранить / восстановить большую модель в tenorflow 2.0 с керасом? - PullRequest
7 голосов
/ 30 апреля 2019

У меня есть большая специальная модель, сделанная с новым tenorflow 2.0 и смешивающими керасами и тензорным потоком.Я хочу сохранить его (архитектура и вес).Точная команда для воспроизведения:

import tensorflow as tf


OUTPUT_CHANNELS = 3

def downsample(filters, size, apply_batchnorm=True):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

  if apply_batchnorm:
    result.add(tf.keras.layers.BatchNormalization())

  result.add(tf.keras.layers.LeakyReLU())

  return result

def upsample(filters, size, apply_dropout=False):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
    tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                    padding='same',
                                    kernel_initializer=initializer,
                                    use_bias=False))

  result.add(tf.keras.layers.BatchNormalization())

  if apply_dropout:
      result.add(tf.keras.layers.Dropout(0.5))

  result.add(tf.keras.layers.ReLU())

  return result


def Generator():
  down_stack = [
    downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)
    downsample(128, 4), # (bs, 64, 64, 128)
    downsample(256, 4), # (bs, 32, 32, 256)
    downsample(512, 4), # (bs, 16, 16, 512)
    downsample(512, 4), # (bs, 8, 8, 512)
    downsample(512, 4), # (bs, 4, 4, 512)
    downsample(512, 4), # (bs, 2, 2, 512)
    downsample(512, 4), # (bs, 1, 1, 512)
  ]

  up_stack = [
    upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
    upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
    upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
    upsample(512, 4), # (bs, 16, 16, 1024)
    upsample(256, 4), # (bs, 32, 32, 512)
    upsample(128, 4), # (bs, 64, 64, 256)
    upsample(64, 4), # (bs, 128, 128, 128)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh') # (bs, 256, 256, 3)

  concat = tf.keras.layers.Concatenate()

  inputs = tf.keras.layers.Input(shape=[None,None,3])
  x = inputs

  # Downsampling through the model
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = concat([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

generator = Generator()
generator.summary()

generator.save('generator.h5')
generator_loaded = tf.keras.models.load_model('generator.h5')

Мне удается сохранить модель с помощью:

generator.save('generator.h5')

Но когда я пытаюсь загрузить ее с помощью:

generator_loaded = tf.keras.models.load_model('generator.h5')

Itникогда не заканчивается (нет сообщения об ошибке).Может модель слишком большая?Я пытался сохранить как JSON с model.to_json(), а также с полным API tf.keras.models.save_model(), но та же проблема, невозможно загрузить его (или, по крайней мере, слишком долго).

Та же проблема в Windows / Linux ис / без графического процессора.

Сохранение и восстановление хорошо работают с полными Keras и простой моделью.

Редактировать

Ответы [ 4 ]

0 голосов
/ 17 мая 2019

Мне удалось сохранить и загрузить пользовательские модели, реализовав функции, аналогичные последовательной модели в Keras.

Ключевыми функциями являются CustomModel.get_config() CustomModel.from_config(), которые также должны существовать на любом из ваших пользовательских слоев (аналогично приведенным ниже функциям, но смотрите слои keras, если вы хотите лучшего понимания):

# In the CustomModel class    
def get_config(self):
    layer_configs = []
    for layer in self.layers:
        layer_configs.append({
            'class_name': layer.__class__.__name__,
            'config': layer.get_config()
        })
    config = {
        'name': self.name,
        'layers': copy.deepcopy(layer_configs),
        "arg1": self.arg1,
        ...
    }
    if self._build_input_shape:
        config['build_input_shape'] = self._build_input_shape
    return config

@classmethod
def from_config(cls, config, custom_objects=None):
    from tensorflow.python.keras import layers as layer_module
    if custom_objects is None:
        custom_objects = {'CustomLayer1Class': CustomLayer1Class, ...}
    else:
        custom_objects = dict(custom_objects, **{'CustomLayer1Class': CustomLayer1Class, ...})

    if 'name' in config:
        name = config['name']
        build_input_shape = config.get('build_input_shape')
        layer_configs = config['layers']
    else:
        name = None
        build_input_shape = None
        layer_configs = config
    model = cls(name=name,
                arg1=config['arg1'],
                should_build_graph=False,
                ...)
    for layer_config in tqdm(layer_configs, 'Loading Layers'):
        layer = layer_module.deserialize(layer_config,
                                         custom_objects=custom_objects)
        model.add(layer) # This function looks at the name of the layers to place them in the right order
    if not model.inputs and build_input_shape:
        model.build(build_input_shape)
    if not model._is_graph_network:
        # Still needs to be built when passed input data.
        model.built = False
    return model

Я также добавил функцию CustomModel.add(), которая добавляет слои один за другим из их конфигурации. Также параметр should_build_graph=False, который гарантирует, что вы не строите график в __init__() при вызове cls().

Тогда функция CustomModel.save() выглядит следующим образом:

    def save(self, filepath, overwrite=True, include_optimizer=True, **kwargs):
        from tensorflow.python.keras.models import save_model  
        save_model(self, filepath, overwrite, include_optimizer)

После этого вы можете сохранить, используя:

model.save("model.h5")
new_model = keras.models.load_model('model.h5',
                                        custom_objects={
                                        'CustomModel': CustomModel,                                                     
                                        'CustomLayer1Class': CustomLayer1Class,
                                        ...
                                        })

Но почему-то этот подход кажется довольно медленным ... С другой стороны, этот подход почти в 30 раз быстрее. Не уверен почему:

    model.save_weights("weights.h5")
    config = model.get_config()
    reinitialized_model = CustomModel.from_config(config)
    reinitialized_model.load_weights("weights.h5")

Я работаю, но это кажется довольно хакерским. Возможно, будущие версии TF2 сделают процесс более понятным.

0 голосов
/ 30 апреля 2019

Еще одним способом сохранения обученной модели является использование модуля pickle в python.

import pickle
pickle.dump(model, open(filename, 'wb'))

Чтобы загрузить модель pickled,

loaded_model = pickle.load(open(filename, 'rb'))

Расширение файла маринада обычно .sav

0 голосов
/ 07 мая 2019

Я нашел временное решение. Кажется, что проблема возникает с последовательным API tf.keras.Sequential, используя функциональный API, tf.keras.models.load_model удается загрузить сохраненную модель. Я надеюсь, что они исправят эту проблему в финальной версии, взгляните на проблему, которую я поднял в github https://github.com/tensorflow/tensorflow/issues/28281.

Приветствия

0 голосов
/ 30 апреля 2019

Попробуйте вместо этого сохранить модель как:

model.save('model_name.model')

Затем загрузите ее с:

model = tf.keras.models.load_model('model_name.model')
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...