Внедрение глобальной потери триплета в тензорном потоке - PullRequest
0 голосов
/ 04 мая 2019

Во-первых, спасибо, что прочитали мой вопрос.Я пытаюсь реализовать глобальную потерю триплетов, описанную в этой статье (https://arxiv.org/pdf/1512.09272.pdf) в моем проекте. Основная идея этой функции потерь может быть суммирована как:

"... 1) свернутьдисперсия двух распределений и среднее значение расстояний между совпадающими парами, и 2) максимизируют среднее значение расстояний между несовпадающими парами ... "

Для моей текущей функции потери триплета я простоминимизировать расстояние между партиями совпадающих пар и максимально увеличить расстояние между партиями несоответствующих пар.Для глобальной потери триплета мне нужно вести учет двух расстояний и вычислять дисперсию и среднее двух распределений для функции потерь.Насколько я понимаю, дисперсия и среднее значение должны быть рассчитаны по всем предыдущим данным обучения.

В настоящее время я могу напечатать соответствующие расстояния, используя следующий неудовлетворительный код.

def global_triplet_loss():
    anchor = anchor_positive_negative_tensor[:, :, 0]
    positive = anchor_positive_negative_tensor[:, :, 1]
    negative = anchor_positive_negative_tensor[:, :, 2]

    Dp = K.sqrt(K.sum(K.square((anchor - positive)), axis=1))
    Dn = K.sqrt(K.sum(K.square((anchor - negative)), axis=1))

    Dp1 = K.mean(Dp)
    Dp1 = K.print_tensor(Dp1, message='mean dist between A and P = ')

    Dn1 = K.mean(Dn)
    Dn1 = K.print_tensor(Dn1, message='mean dist between A and N = ')

    b = Dp - Dn + Dn1 - Dn1 + Dp1 - Dp1

    return K.maximum(0.0, 1 + K.mean(b))

Я пробовал несколько вещей, таких как преобразование значения тензора расстояния в целое число и добавление его во внешний список, который я мог бы использовать для хранения расстояний и выполнения соответствующих вычислений в функции потерь.Любые идеи о том, как реализовать это, были бы очень признательны, все, что я сделал до сих пор, было полностью неэффективным, медленным и неэффективным.

...