Tensorflow one custom metri c для моделей с несколькими выходами - PullRequest
1 голос
/ 29 января 2020

Я не могу найти информацию в документации, поэтому спрашиваю здесь.

У меня есть модель с несколькими выходами с 3 различными выходами:

model = tf.keras.Model(inputs=[input], outputs=[output1, output2, output3])

Предсказанные метки для проверки: построенный из этих 3 выходов, чтобы сформировать только один, это шаг постобработки. Набор данных, используемый для обучения, представляет собой набор данных из этих 3 промежуточных выходных данных, для проверки я оцениваю набор данных из меток вместо 3-х типов промежуточных данных.

Я хотел бы оценить свою модель с использованием пользовательского показателя c, которые обрабатывают пост-обработку и сравнение с основной истиной.

Мой вопрос , в коде пользовательского метри c, будет y_pred список списка 3 выхода модели?

class MyCustomMetric(tf.keras.metrics.Metric):

  def __init__(self, name='my_custom_metric', **kwargs):
    super(MyCustomMetric, self).__init__(name=name, **kwargs)

  def update_state(self, y_true, y_pred, sample_weight=None):
    # ? is y_pred a list [batch_output_1, batch_output_2, batch_output_3] ? 

  def result(self):
    pass 

# one single metric handling the 3 outputs?
model.compile(optimizer=tf.compat.v1.train.RMSPropOptimizer(0.01),
              loss=tf.keras.losses.categorical_crossentropy,
              metrics=[MyCustomMetric()])

1 Ответ

1 голос
/ 30 января 2020

С данным определением модели это стандартная модель с несколькими выходами.

model = tf.keras.Model(inputs=[input], outputs=[output_1, output_2, output_3])

В общем, все (пользовательские) метрики, а также (пользовательские) потери будут вызываться на каждом выходе отдельно ( как у_пред)! В функции loss / metri c вы увидите только один выход вместе с одним соответствующим целевым тензором. Передав список функций потерь (длина == количество выходов вашей модели), вы можете указать, какие потери будут использоваться для какого выхода:

model.compile(optimizer=Adam(), loss=[loss_for_output_1, loss_for_output_2, loss_for_output_3], loss_weights=[1, 4, 8])

Общая потеря (которая является целевой функцией для минимизации). ) будет аддитивной комбинацией всех потерь, умноженных на данные веса потерь.

Это почти то же самое для метрик! Здесь вы можете передать (что касается потери) список (длина == количество выходов) метрик и сообщить Keras, какой metri c использовать для какой из ваших моделей выводит.

model.compile(optimizer=Adam(), loss='mse', metrics=[metrics_for_output_1, metrics_for_output2, metrics_for_output3])

Здесь metrics_for_output_X может быть либо функцией, либо списком функций, которые все вызываются с одной соответствующей output_X как y_pred.

Это подробно объясняется в документации моделей с несколькими выходами в Keras. Они также показывают примеры использования словарей (чтобы отобразить функции loss / metri c на конкретный вывод c) вместо списков. https://keras.io/getting-started/functional-api-guide/#multi модели с входным и несколькими выходами

Дополнительная информация:

Если я вас правильно понимаю, вы хотите обучить своих В модели используется функция потерь, сравнивающая три выходных сигнала модели с тремя основными значениями истинности, и требуется выполнить некоторую оценку производительности путем сравнения производного значения из трех выходных данных модели и единственного основного значения истинности. Обычно модель обучается с той же целью, по которой она оценивается, в противном случае вы можете получить худшие результаты при оценке вашей модели!

В любом случае ... для оценки вашей модели по одной метке, я предлагаю вам либо:

1. (Чистое решение)

Перепишите модель и включите этапы последующей обработки. Добавьте все необходимые операции (как слои) и сопоставьте их со вспомогательным выходом. Для обучения вашей модели вы можете установить значение loss_weight вспомогательного выхода равным нулю. Объедините ваши наборы данных, чтобы вы могли подать в вашу модель ввод модели, промежуточные выходные результаты, а также метки. Как объяснено выше, теперь вы можете определить метрику c, сравнивая выходные данные вспомогательной модели с заданными целевыми метками.

2.

Или вы тренируете свою модель и выводите Метри c Например, в пользовательском обратном вызове путем расчета ваших шагов пост-обработки на трех выходных данных model.predict (вход) Это потребует написания пользовательских сводок, если вы хотите отслеживать эти значения в своей тензорной доске! Вот почему я бы не рекомендовал это решение.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...