Каков рекомендуемый способ сериализации `tf.Module`s? - PullRequest
2 голосов
/ 17 января 2020

У меня есть класс tf.Module, который содержит (не может быть выбран) tf.keras.Model в качестве подмодуля. Интересно, каков рекомендуемый способ сериализации tf.Module в этом случае?

Я рассмотрел два способа:

  1. Использование чего-то похожего на tf.keras.Model.save. Я надеялся, что, возможно, tf.Module s сможет сохранить вложенные модули так же, как tf.Model.save. Однако tf.Module не имеет такой возможности.
  2. Pickling, что было бы простым способом сериализации tf.Module, но я не могу этого сделать, потому что tf.keras.Model не выбирается.

Вот пример кода, который в настоящее время не работает:

import pickle

import tensorflow as tf


class TestModule(tf.Module):
    def __init__(self, model):
        self.model = model


def main():
    x = tf.keras.layers.Input((3, ))
    y = tf.keras.layers.Dense(5)(x)
    # Note, model *is not* picklable.
    model = tf.keras.Model(x, y)

    _ = model(tf.random.uniform((1, 3)))

    module_1 = TestModule(model)
    module_2 = pickle.loads(pickle.dumps(module_1))

    for variable_1, variable_2 in zip(module_1.model.trainable_variables,
                                      module_2.model.trainable_variables):
        tf.debugging.assert_equal(variable_1, variable_2)


if __name__ == '__main__':
    main()

Должен ли я написать пользовательские функции выбора (например, __{get,set}state__) для каждого tf.Module или создать аналогичный .save метод, который есть у keras.Model?

1 Ответ

0 голосов
/ 18 января 2020

Вы можете использовать Формат сохраненной модели для сохранения пользовательского подкласса tf.Module.

Для Tensorflow 2.1 работает следующее:

import tensorflow as tf

class TestModule(tf.Module):
    def __init__(self, model):
        self.model = model


x = tf.keras.layers.Input((3, ))
y = tf.keras.layers.Dense(5)(x)
model = tf.keras.Model(x, y)
module_1 = TestModule(model)

tf.saved_model.save(module_1, "./foo")

Чтобы загрузить обратно:
imported = tf.saved_model.load("foo")

Утверждение
module_1 == imported (или подобное) повысит AssertionError, так как после загрузки мы имеем дело с другим объектом Tensorflow. Однако мы можем перебрать веса модели и сравнить их поэлементно:

original_weights = module_1.model.weights
imported_weights = imported.model.variables.weights

for weight_idx, _ in enumerate(original_weights):
  assert (
      original_weights[weight_idx].numpy() == imported_weights[weight_idx].numpy()
      ).all()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...