Создайте класс вокруг создания модели, который содержит конфигурацию модели, и используйте функцию-член для параметра metric_fn
.
class ModelBuilder:
# constructor storing configuration options in self
def __init__(self, labels, other_config_args):
self.labels = labels
...
# Function for building the estimator with multiple heads (multi-objective)
def build_estimator(self, func_args):
heads = []
for label in self.labels:
heads.append(tf.estimator.MultiClassHead(n_classes=self.nclasses[label], name=label))
head = tf.estimator.MultiHead(heads)
estimator = tf.estimator.DNNEstimator(head=heads,...) # or whatever type of estimator you want
estimator = tf.estimator.add_metrics(estimator, self.model_metrics)
return estimator
# Member function that adds metrics to the estimator based on the model configuration params
def model_metrics(self, labels, predictions, features):
metrics = {}
for label in self.labels: # generate a metric for each head name
metrics['metric_name'] = metric_func(features,labels,predictions[(label,'logits')])
return metrics