Функция Keras Custom Loss для определения, находится ли изображение в квадранте - PullRequest
0 голосов
/ 06 февраля 2019

Я пытаюсь создать пользовательскую функцию потерь, которая определяет, находится ли большая часть изображения y_pred 32x32 в том же квадранте, что и изображение y_true.Моя идея состояла в том, чтобы создать маску 32x32 для каждого квадранта и умножить входные изображения на эти маски, затем сложить маски и найти максимальную сумму для идентификации квадранта.Если бы максимальная сумма была связана с одним и тем же квадрантом, то это идеальное совпадение, иначе я бы вычислил дисперсию.

Я пробовал следующее (среди многих других вещей):

arrtl = np.zeros((32, 32, 1))
arrtl[0:16, 0:16] = 1

arrtr = np.zeros((32, 32, 1))
arrtr[0:16, 16:]

arrbl = np.zeros((32, 32, 1))
arrbl[16:, 0:16]

arrbr = np.zeros((32, 32, 1))
arrbr[16:, 16:]

def custom_loss(y_true, y_pred):
    kernels = [K.variable(arrtl), K.variable(arrtr), K.variable(arrbl), K.variable(arrbr)]
    true_quadrant = kernels* y_true
    true_quadrant = K.map_fn(K.sum, true_quadrant)
    true_max = K.max(true_quadrant)
    true_max_q = K.tf.where(K.tf.equal(true_quadrant, 1))

    pred_quadrant = y_pred * kernels
    pred_quadrant = K.map_fn(K.sum, pred_quadrant)
    pred_max = K.max(pred_quadrant)
    pred_max_q = K.tf.where(K.tf.equal(pred_quadrant, pred_max))

    res = K.abs(true_max_q - pred_max_q)
    return res

Этот код намеревается:

  1. Создать маску для каждого квадранта (например, arrtl - это массив для верхней левой маски)
  2. Умножить 32x32 изображения на маски для извлечениятолько данные в каждом квадранте.
  3. Суммируйте данные в каждом квадранте, чтобы увидеть, какой квадрант содержит наибольшее количество данных.
  4. Выберите индекс максимального квадранта (квадрант 0-3)
  5. Выполните вышеуказанные 4 шага как для y_pred, так и для y_true, вычтя y_pred из квадрантов y_true в качестве результата потери.

Я знаю, что это, вероятно, дико далеко от правильной реализации, но яне могу найти какой-либо ресурс, который помог бы мне.Это часть тестовой сети глубокого обучения, в которой я пытаюсь понять, как сделать сложные функции потери.Я знаю, что могу заставить сеть обнаруживать и генерировать контент, создавая свои слои вне функции потерь, но я пытаюсь узнать, как работают функции потерь.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...