Как метрики должны быть добавлены к многоголовному оценщику TensorFlow? - PullRequest
0 голосов
/ 24 января 2020

Ранее я создавал метрики для классификатора TensorFlow, ссылаясь на predictions['logits'] для вычисления метрик. Я изменил модель с классификатора на оценщик, чтобы включить многоцелевое обучение (используя MultiHead). Однако это привело к ошибке Python, поскольку теперь элементы predictions вводятся с помощью пар имени заголовка и исходного ключа, например ('label1', 'logits') для заголовка с именем 'label1' .

Я хотел бы разрешить динамическое создание метрики c на основе файла конфигурации, чтобы упростить обучение и тестирование различных моделей с различными комбинациями меток. Теперь проблема заключается в том, что параметр metric_fn для tf.estimator.add_metrics не принимает никаких дополнительных параметров для учета динамически определенных или построенных метрик.

Как можно создать оценщик с несколькими головками и пользовательскими метриками для каждой головы

1 Ответ

1 голос
/ 24 января 2020

Создайте класс вокруг создания модели, который содержит конфигурацию модели, и используйте функцию-член для параметра metric_fn.

class ModelBuilder: # constructor storing configuration options in self def __init__(self, labels, other_config_args): self.labels = labels ... # Function for building the estimator with multiple heads (multi-objective) def build_estimator(self, func_args): heads = [] for label in self.labels: heads.append(tf.estimator.MultiClassHead(n_classes=self.nclasses[label], name=label)) head = tf.estimator.MultiHead(heads) estimator = tf.estimator.DNNEstimator(head=heads,...) # or whatever type of estimator you want estimator = tf.estimator.add_metrics(estimator, self.model_metrics) return estimator # Member function that adds metrics to the estimator based on the model configuration params def model_metrics(self, labels, predictions, features): metrics = {} for label in self.labels: # generate a metric for each head name metrics['metric_name'] = metric_func(features,labels,predictions[(label,'logits')]) return metrics

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...