Как установить потерю для train_op, используя tf estimator API, когда потеря зависит от ввода и реализуется через пользовательский уровень потерь? - PullRequest
1 голос
/ 23 апреля 2019

Я внедряю модель CNN с tf.estimator API.

Вкратце, функция затрат / потерь зависит от входных данных и содержит переменные, которые должны обновляться во время обучения. Я реализовал пользовательский слой с идентичным отображением только для вызова add_loss() и удержания обучаемых переменных для потери.

Однако, когда я попытался указать train_op для EstimatorSpec (используя AdamOptimizer), я понятия не имел, как извлечь потери и скорректировать их до optimizer.minimize().

Проблема возникла у меня, когда я пытался реализовать потерю с мульти-максимальным гауссовым вероятностью, предложенную в Kendall_CVPR2018 . Я принял общий подход в качестве примера кода, предоставленного автором статьи на Github , который определяет пользовательский слой для потери.

К сожалению, вышеупомянутый код использует Keras, в то время как я пытаюсь протестировать его с tensorflow, точнее, tf.estimator. В Keras при вызове model.compile () мы можем указать None в качестве аргумента loss. Но я полагаю, что мы не можем передать None оптимизатору в tenorflow.

def model_fn(features, labels, mode, params):
    ...
    xs = ts.reshape(xs, shape=[-1, ...])
    nn_params = dict(...)
    ys_out = cnn_blabla(x, mode, ** nn_params)
    ...

    loss=???

    ...
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimmizer(params['LEARNING_RATE'])
        train_op = optimizer.minimize(loss)
    ...
    return tf.estimator.EstimatorSpec(...)


def cnn_blabla(x, mode, n_outputs, ...):
    with tf.variable_scope("blabla", reuse=tf.AUTO_REUSE):
        layer_out_1 = conv(x, ..., activation=..., name=...)
        layer_out_2 = conv(layer_out_1, ..., activation=..., name=...)
        ...
        layer_out_v = conv(layer_out_u, ..., activation=..., name=...)
        out = CustomLossLayer(n_outputs=n_outputs, name="loss_blabla")(layer_out_v)
    return out


def conv(...):
    ...

Я ожидаю обучить модель с пользовательской потерей, через tf.estimator в тензорном потоке.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...