Роль input_tensors в tf.keras.models.clone_model - PullRequest
0 голосов
/ 12 января 2020

Я пытаюсь скопировать существующую модель keras. Ниже приведен пример кода, который я создал, и кажется, что он работает должным образом.

model = CreateSimpleModel()
model.compile(loss="sparse_categorical_crossentropy",
                 optimizer="adam",
                 metrics=["accuracy"])

model.summary()


model_cloned = tf.keras.models.clone_model(model)
model_cloned.set_weights(model.get_weights())

print(model(np.array([[1, 2]])))
print(model_cloned(np.array([[1, 2]])))

Однако, если мы посмотрим на официальную документацию о tf.keras.models.clone_model на следующей странице, есть параметр с именем input_tensors.

https://www.tensorflow.org/api_docs/python/tf/keras/models/clone_model

Я не очень уверен в роли этого параметра. Из приведенного выше примера кода я не совсем понимаю, зачем он нужен в некоторых случаях. Может ли кто-нибудь объяснить с некоторыми примерами?

...