У меня есть пользовательская потеря, которая использует один из входов для модели.
def closs(labels,latent_dim):
def loss(y_true,y_pred):
return metric_learning.contrastive_loss(labels=labels,
embeddings_anchor=y_pred[:,:latent_dim],
embeddings_positive=y_pred[:,latent_dim:])
return loss
Где метки - это входные данные для модели. Архитектура модели:
def build_model():
left_input = Input(shape=(2900,1))
right_input = Input(shape=(2900,1))
label = Input(shape=(1,))
encoder = build_encoder()
left_embed = encoder(left_input)
right_embed = encoder(right_input)
embeds = Concatenate()([left_embed,right_embed])
model = Model(inputs=[left_input,right_input,label],outputs=[embeds])
return model, label
Затем я использую возвращенную метку для компиляции модели:
model,label = build_model()
model.compile(optimizer='adam',loss=closs(label,256))
Но когда я пытаюсь загрузить модель, я должен пройти это потеря как custom_object, так что-то вроде этого:
model = load_model('model/cl_model.h5',custom_objects={'loss':closs(xyz,256)})
Проблема в том, что я загружаю модель в другой сценарий python, и поэтому у меня нет входного объекта "label" , Как я могу преодолеть это?