У меня есть класс tf.Module
, который содержит (не может быть выбран) tf.keras.Model
в качестве подмодуля. Интересно, каков рекомендуемый способ сериализации tf.Module
в этом случае?
Я рассмотрел два способа:
- Использование чего-то похожего на
tf.keras.Model.save
. Я надеялся, что, возможно, tf.Module
s сможет сохранить вложенные модули так же, как tf.Model.save
. Однако tf.Module
не имеет такой возможности. - Pickling, что было бы простым способом сериализации
tf.Module
, но я не могу этого сделать, потому что tf.keras.Model
не выбирается.
Вот пример кода, который в настоящее время не работает:
import pickle
import tensorflow as tf
class TestModule(tf.Module):
def __init__(self, model):
self.model = model
def main():
x = tf.keras.layers.Input((3, ))
y = tf.keras.layers.Dense(5)(x)
# Note, model *is not* picklable.
model = tf.keras.Model(x, y)
_ = model(tf.random.uniform((1, 3)))
module_1 = TestModule(model)
module_2 = pickle.loads(pickle.dumps(module_1))
for variable_1, variable_2 in zip(module_1.model.trainable_variables,
module_2.model.trainable_variables):
tf.debugging.assert_equal(variable_1, variable_2)
if __name__ == '__main__':
main()
Должен ли я написать пользовательские функции выбора (например, __{get,set}state__
) для каждого tf.Module
или создать аналогичный .save
метод, который есть у keras.Model
?