Как убрать дублирование этих определений keras metri c? - PullRequest
1 голос
/ 05 февраля 2020

Keras предоставляет метрики точности, точности и отзыва, которые можно использовать для оценки вашей модели, но эти метрики могут оценивать только все y_true и y_pred. Я хочу, чтобы он оценил только подмножество данных. y_true[..., 0:20] в моих данных содержит двоичные значения, которые я хочу оценить, но y_true[..., 20:40] содержит данные другого типа.

Поэтому я изменил классы точности и отзыва, чтобы оценивать только по первым 20 каналам моих данных , Я сделал это, создав подклассы этих метрик и попросив их нарезать данные перед оценкой.

from tensorflow import keras as kr

class SliceBinaryAccuracy(kr.metrics.BinaryAccuracy):
    """Slice data before evaluating accuracy. To be used as Keras metric"""

    def __init__(self, channels, *args, **kwargs):
        self.channels = channels
        super().__init__(*args, **kwargs)

    def _slice(self, y):
        return y[..., : self.channels]

    def __call__(self, y_true, y_pred, *args, **kwargs):
        y_true = self._slice(y_true)
        y_pred = self._slice(y_pred)
        return super().__call__(y_true, y_pred, *args, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = self._slice(y_true)
        y_pred = self._slice(y_pred)
        super().update_state(y_true, y_pred, sample_weight=sample_weight)


class SlicePrecision(kr.metrics.Precision):
    """Slice data before evaluating precision. To be used as Keras metric"""

    def __init__(self, channels, *args, **kwargs):
        self.channels = channels
        super().__init__(*args, **kwargs)

    def _slice(self, y):
        return y[..., : self.channels]

    def __call__(self, y_true, y_pred, *args, **kwargs):
        y_true = self._slice(y_true)
        y_pred = self._slice(y_pred)
        return super().__call__(y_true, y_pred, *args, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = self._slice(y_true)
        y_pred = self._slice(y_pred)
        super().update_state(y_true, y_pred, sample_weight=sample_weight)


class SliceRecall(kr.metrics.Recall):
    """Slice data before evaluating recall. To be used as Keras metric"""

    def __init__(self, channels, *args, **kwargs):
        self.channels = channels
        super().__init__(*args, **kwargs)

    def _slice(self, y):
        return y[..., : self.channels]

    def __call__(self, y_true, y_pred, *args, **kwargs):
        y_true = self._slice(y_true)
        y_pred = self._slice(y_pred)
        return super().__call__(y_true, y_pred, *args, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = self._slice(y_true)
        y_pred = self._slice(y_pred)
        super().update_state(y_true, y_pred, sample_weight=sample_weight)

Способ использования вышеуказанных классов выглядит следующим образом:

model.compile('adam', loss='mse', metrics=[SliceBinaryAccuracy(20), SlicePrecision(20), SliceRecall(20)])

Код работает, но Я обнаружил, что код довольно длинный. Я вижу много дубликатов из этих 3 метрик, как я могу обобщить эти классы в один класс или что-то лучше дизайн? Пожалуйста, приведите пример кода, если это возможно.

1 Ответ

1 голос
/ 12 февраля 2020

Я согласен, что в этих классах слишком много повторений, единственное различие между ними - это метри c, которые они подклассируют. Я думаю, что это хороший случай, чтобы применить какую-то фабричную модель. У меня есть небольшая функция, которую я создал для динамического подкласса метрик.

def MetricFactory(cls, channels):
  '''Takes a keras metric class and channels value and returns the instantiated subclassed metric'''

  class DynamicMetric(cls):
    def __init__(self, channels, *args, **kwargs):
      self.channels = channels
      super().__init__(*args, **kwargs)

    def _slice(self, y):
      return y[..., : self.channels]

    def __call__(self, y_true, y_pred, *args, **kwargs):
      y_true = self._slice(y_true)
      y_pred = self._slice(y_pred)
      return super().__call__(y_true, y_pred, *args, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
      y_true = self._slice(y_true)
      y_pred = self._slice(y_pred)
      super().update_state(y_true, y_pred, sample_weight=sample_weight)

  x = DynamicMetric(channels)
  return x

Тогда вы можете использовать ее следующим образом:

metrics = [MetricFactory(kr.metrics.BinaryAccuracy, 20), MetricFactory(kr.metrics.Precision, 20), MetricFactory(kr.metrics.Recall, 20)]
model.compile('adam', loss='mse', metrics=metrics)

Поскольку переписанные методы в точности совпадают для трех метрик, которые вы подклассифицируете, функция может внедрить их непосредственно в новый класс. Функция для простоты возвращает созданный экземпляр подкласса, но вместо этого вы можете вернуть новый класс. Стоит отметить, что этот конкретный подход не сработал бы, если бы вам пришлось передавать методы, которые вы хотите перезаписать, в качестве параметров, и, вероятно, потребовалось бы использовать метаклассы или чудесные черные маги c в строках этого потока,

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...