Кросс-энтропийные потери на основе расстояния могут быть описаны как:
data:image/s3,"s3://crabby-images/e230d/e230d705e879f30bcf5114af57e28931dc7b045e" alt="enter image description here"
Орудие Tensorflow выглядит следующим образом:
def mnist_net(images):
inputs = tf.transpose(images, perm=[0,2,3,1])
conv1 = Conv(inputs, [5,5,1,32], activation=ReLU)
conv2 = Conv(conv1, [5,5,32,32], activation=ReLU)
pool1 = Max_pool(conv2, padding='VALID')
conv3 = Conv(pool1, [5,5,32,64], activation=ReLU)
conv4 = Conv(conv3, [5,5,64,64], activation=ReLU)
pool2 = Max_pool(conv4, padding='VALID')
conv5 = Conv(pool2, [5,5,64,128], activation=ReLU)
conv6 = Conv(conv5, [5,5,128,128], activation=ReLU)
pool3 = Max_pool(conv6, padding='VALID')
fc1 = FC(tf.reshape(pool3, [-1, 3*3*128]), 3*3*128, 2)
return fc1
def construct_center(features, num_classes):
len_features = features.get_shape()[1]
centers = tf.get_variable('centers', [num_classes, len_features], dtype=tf.float32,
initializer=tf.constant_initializer(0))
return centers
def dce_loss(features, labels, centers, T):
dist = distance(features, centers)
logits = -dist / T
mean_loss = softmax_loss(logits, labels)
return mean_loss
def distance(features, centers):
f_2 = tf.reduce_sum(tf.pow(features, 2), axis=1, keep_dims=True)
c_2 = tf.reduce_sum(tf.pow(centers, 2), axis=1, keep_dims=True)
dist = f_2 - 2*tf.matmul(features, centers, transpose_b=True) + tf.transpose(c_2, perm=[1,0])
return dist
def softmax_loss(logits, labels):
labels = tf.to_int32(labels)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
logits=logits, name='xentropy')
return tf.reduce_mean(cross_entropy, name='xentropy_mean')
features = mnist_net(images) #features.shape=[batch_size, len_features]
centers = construct_center(features, FLAGS.num_classes) #centers.shape = [num_classes, len_features]
loss = dce_loss(features, labels, centers, gamma)
Обратите внимание, что центры (прототипы) инициализируются как нулевые векторы, а затем автоматически обновляются с помощью процесса тонкой настройки сети 'mnist _net'. Теперь вопрос в том, как реализовать эту нестандартную потерю в Керасе?