Пользовательские оценщики TensorFlow: определение ошибки триггеров спецификации оценщика - PullRequest
0 голосов
/ 15 ноября 2018

В моем model_fn для пользовательского Оценщик Я пытаюсь обобщить некоторые аспекты. Пока я возился с этой идеей, я натолкнулся на кое-что странное.

Если я определю EstimatorSpec и верну _not_ , он все равно будет действовать так, как если бы это было возвращено. A Colab со всем кодом доступен.

Для подтверждения концепции я только что изменил несколько строк в (buggy)_model_fn и (buggy)_mode_predict (код ниже, но также ознакомился с Colab ).

Почему инициализация EstimatorSpec , независимо от области действия, изменяет поведение Estimator ?

model_fn

функциональный

def model_fn(features, labels, mode, params):
    MODEL = {'features': features, 'labels': labels, 'mode': mode, 'params': params}

    # send the features through the graph
    MODEL = build_fn(MODEL)

    # prediction
    MODEL['predictions'] = {'labels': MODEL['net_logits']}

    MODEL['export_outputs'] = {
        k: tf.estimator.export.PredictOutput(v) for k, v in MODEL['predictions'].items()
    }


    if mode == tf.estimator.ModeKeys.PREDICT: 
      return mode_predict(MODEL)

    # calculate the loss
    MODEL = loss_fn(MODEL)

    # calculate all metrics and send them to tf.summary
    MODEL = metrics_fn(MODEL)

    if mode == tf.estimator.ModeKeys.EVAL: 
      return mode_eval(MODEL)

    if mode == tf.estimator.ModeKeys.TRAIN: 
      return mode_train(MODEL)

багги

def buggy_model_fn(features, labels, mode, params):
    MODEL = {'features': features, 'labels': labels, 'mode': mode, 'params': params}

    # send the features through the graph
    MODEL = build_fn(MODEL)



    # prediction
    # START BUGS HERE -----------------------------------------------
    MODEL = buggy_mode_predict(MODEL)
    if mode == tf.estimator.ModeKeys.PREDICT:
      return MODEL['PREDICT_SPEC']
    # END BUGS HERE -----------------------------------------------



    # calculate the loss
    MODEL = loss_fn(MODEL)

    # calculate all metrics and send them to tf.summary
    MODEL = metrics_fn(MODEL)

    if mode == tf.estimator.ModeKeys.EVAL: 
      return mode_eval(MODEL)

    if mode == tf.estimator.ModeKeys.TRAIN: 
      return mode_train(MODEL)

mode_predict

функциональный

def mode_predict(model):
    """How to predict given the model.

    Args:
        model (dict): a `dict` containing the model

    Returns:
        spec (`EstimatorSpec`_): Ops and objects returned from a model_fn and passed to an Estimator

    .. _EstimatorSpec:
        https://www.tensorflow.org/api_docs/python/tf/estimator/EstimatorSpec

    """
    # do the predictions here

    spec = tf.estimator.EstimatorSpec(
        mode           = model['mode'],
        predictions    = model['predictions'],
        export_outputs = model['export_outputs']
    )
    return spec

багги

def buggy_mode_predict(model):
    # do the predictions here
    model['predictions'] = {'labels': model['net_logits']}

    model['export_outputs'] = {
        k: tf.estimator.export.PredictOutput(v) for k, v in model['predictions'].items()
    }

    spec = tf.estimator.EstimatorSpec(
        mode           = model['mode'],
        predictions    = model['predictions'],
        export_outputs = model['export_outputs']
    )
    # START BUGS HERE -----------------------------------------------
    model['PREDICT_SPEC'] = spec
    # END BUGS HERE -----------------------------------------------
    return model
...