керас компилируется с набором данных и гибкими потерями / метриками - PullRequest
0 голосов
/ 19 февраля 2019

Я портирую кучу кода из tf.estimator.Estimator API в tf.keras, используя tf.data.Dataset s, и я надеюсь остаться как можно ближе к предоставленному compile / fit.Я разочарован loss и metrics аргументами компиляции.

По сути, я хотел бы использовать функцию потерь, которая использует несколько выходов и меток неаддитивным способом, то есть я хочупредоставить

def custom_loss(all_labels, model_outputs):
    """
    Args:
        all_labels: all labels in the dataset, as a single tensor, tuple or dict
        model_outputs: all outputs of model as a single tensor, tuple or dict

    Returns:
        single loss tensor to be averaged.
    """"
    ...

Я не могу предоставить это compile, поскольку, насколько мне известно, он поддерживает только взвешенные суммы потерь для каждого выхода / метки и делает предположения о форме каждой меткина основе соответствующей модели вывода.Я не могу создать его отдельно и использовать model.add_loss, потому что у меня никогда не будет явного доступа к тензору меток, если я хочу, чтобы model.fit обрабатывал итерацию набора данных.Я подумал о сведении / объединении всех выходов и меток вместе, но тогда я не могу контролировать несколько metrics.

Я могу написать свой собственный цикл обучения, используя model.train_on_batch, но это заставляет меня копировать поведениеуже реализовано в fit, таких как итерация набора данных, обратные вызовы, проверка, стратегии распределения и т. д.


В качестве примера я хотел бы повторить следующую оценку.

def model_fn(features, labels, mode):
    outputs = get_outputs(features)  # dict
    loss = custom_loss(labels, outputs)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
    eval_metrics_op = {
      'a_mean': tf.metrics.mean(outputs['a'])
    }
    return tf.estimator.EstimatorSpec(
        loss=loss, train_op=train_op, mode=mode, eval_metric_ops=eval_metric_ops)

estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(dataset_fn)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...