Я пытаюсь скопировать существующую модель keras. Ниже приведен пример кода, который я создал, и кажется, что он работает должным образом.
model = CreateSimpleModel()
model.compile(loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"])
model.summary()
model_cloned = tf.keras.models.clone_model(model)
model_cloned.set_weights(model.get_weights())
print(model(np.array([[1, 2]])))
print(model_cloned(np.array([[1, 2]])))
Однако, если мы посмотрим на официальную документацию о tf.keras.models.clone_model
на следующей странице, есть параметр с именем input_tensors
.
https://www.tensorflow.org/api_docs/python/tf/keras/models/clone_model
Я не очень уверен в роли этого параметра. Из приведенного выше примера кода я не совсем понимаю, зачем он нужен в некоторых случаях. Может ли кто-нибудь объяснить с некоторыми примерами?