Как сделать потерю перекрестной энтропии (DCE) для классификации в Керасе? - PullRequest
0 голосов
/ 04 марта 2020

Кросс-энтропийные потери на основе расстояния могут быть описаны как:

enter image description here

Орудие Tensorflow выглядит следующим образом:

def mnist_net(images):
    inputs = tf.transpose(images, perm=[0,2,3,1])

    conv1 = Conv(inputs, [5,5,1,32], activation=ReLU)

    conv2 = Conv(conv1, [5,5,32,32], activation=ReLU)

    pool1 = Max_pool(conv2, padding='VALID')

    conv3 = Conv(pool1, [5,5,32,64], activation=ReLU)

    conv4 = Conv(conv3, [5,5,64,64], activation=ReLU)

    pool2 = Max_pool(conv4, padding='VALID')

    conv5 = Conv(pool2, [5,5,64,128], activation=ReLU)

    conv6 = Conv(conv5, [5,5,128,128], activation=ReLU)

    pool3 = Max_pool(conv6, padding='VALID')

    fc1 = FC(tf.reshape(pool3, [-1, 3*3*128]), 3*3*128, 2)

    return fc1 

def construct_center(features, num_classes):
    len_features = features.get_shape()[1]
    centers = tf.get_variable('centers', [num_classes, len_features], dtype=tf.float32,
        initializer=tf.constant_initializer(0))

    return centers

def dce_loss(features, labels, centers, T):
    dist = distance(features, centers)
    logits = -dist / T
    mean_loss = softmax_loss(logits, labels)

    return mean_loss

def distance(features, centers):
    f_2 = tf.reduce_sum(tf.pow(features, 2), axis=1, keep_dims=True)
    c_2 = tf.reduce_sum(tf.pow(centers, 2), axis=1, keep_dims=True)
    dist = f_2 - 2*tf.matmul(features, centers, transpose_b=True) + tf.transpose(c_2, perm=[1,0])

    return dist

def softmax_loss(logits, labels):
    labels = tf.to_int32(labels)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
        logits=logits, name='xentropy')

    return tf.reduce_mean(cross_entropy, name='xentropy_mean')

    features = mnist_net(images) #features.shape=[batch_size, len_features]
    centers = construct_center(features, FLAGS.num_classes) #centers.shape = [num_classes, len_features]
    loss = dce_loss(features, labels, centers, gamma)

Обратите внимание, что центры (прототипы) инициализируются как нулевые векторы, а затем автоматически обновляются с помощью процесса тонкой настройки сети 'mnist _net'. Теперь вопрос в том, как реализовать эту нестандартную потерю в Керасе?

...