В моем model_fn
для пользовательского Оценщик Я пытаюсь обобщить некоторые аспекты. Пока я возился с этой идеей, я натолкнулся на кое-что странное.
Если я определю EstimatorSpec и верну _not_ , он все равно будет действовать так, как если бы это было возвращено. A Colab со всем кодом доступен.
Для подтверждения концепции я только что изменил несколько строк в (buggy)_model_fn
и (buggy)_mode_predict
(код ниже, но также ознакомился с Colab ).
Почему инициализация EstimatorSpec , независимо от области действия, изменяет поведение Estimator ?
model_fn
функциональный
def model_fn(features, labels, mode, params):
MODEL = {'features': features, 'labels': labels, 'mode': mode, 'params': params}
# send the features through the graph
MODEL = build_fn(MODEL)
# prediction
MODEL['predictions'] = {'labels': MODEL['net_logits']}
MODEL['export_outputs'] = {
k: tf.estimator.export.PredictOutput(v) for k, v in MODEL['predictions'].items()
}
if mode == tf.estimator.ModeKeys.PREDICT:
return mode_predict(MODEL)
# calculate the loss
MODEL = loss_fn(MODEL)
# calculate all metrics and send them to tf.summary
MODEL = metrics_fn(MODEL)
if mode == tf.estimator.ModeKeys.EVAL:
return mode_eval(MODEL)
if mode == tf.estimator.ModeKeys.TRAIN:
return mode_train(MODEL)
багги
def buggy_model_fn(features, labels, mode, params):
MODEL = {'features': features, 'labels': labels, 'mode': mode, 'params': params}
# send the features through the graph
MODEL = build_fn(MODEL)
# prediction
# START BUGS HERE -----------------------------------------------
MODEL = buggy_mode_predict(MODEL)
if mode == tf.estimator.ModeKeys.PREDICT:
return MODEL['PREDICT_SPEC']
# END BUGS HERE -----------------------------------------------
# calculate the loss
MODEL = loss_fn(MODEL)
# calculate all metrics and send them to tf.summary
MODEL = metrics_fn(MODEL)
if mode == tf.estimator.ModeKeys.EVAL:
return mode_eval(MODEL)
if mode == tf.estimator.ModeKeys.TRAIN:
return mode_train(MODEL)
mode_predict
функциональный
def mode_predict(model):
"""How to predict given the model.
Args:
model (dict): a `dict` containing the model
Returns:
spec (`EstimatorSpec`_): Ops and objects returned from a model_fn and passed to an Estimator
.. _EstimatorSpec:
https://www.tensorflow.org/api_docs/python/tf/estimator/EstimatorSpec
"""
# do the predictions here
spec = tf.estimator.EstimatorSpec(
mode = model['mode'],
predictions = model['predictions'],
export_outputs = model['export_outputs']
)
return spec
багги
def buggy_mode_predict(model):
# do the predictions here
model['predictions'] = {'labels': model['net_logits']}
model['export_outputs'] = {
k: tf.estimator.export.PredictOutput(v) for k, v in model['predictions'].items()
}
spec = tf.estimator.EstimatorSpec(
mode = model['mode'],
predictions = model['predictions'],
export_outputs = model['export_outputs']
)
# START BUGS HERE -----------------------------------------------
model['PREDICT_SPEC'] = spec
# END BUGS HERE -----------------------------------------------
return model