Итак, я реализую потери центра: https://ydwen.github.io/papers/WenECCV16.pdf и у меня возникла проблема с обновлением весов в моем слое, что здесь означает обновление центров в потере центра. Когда я печатаю свои class_centers, как это tf.print(self.class_centers, summarize=-1, output_stream='file:///tensors.txt')
, они никогда не меняются. Когда я печатаю другие переменные, они кажутся нормальными, поэтому я могу думать только о том, что add_update () не делает то, что должен.
Пользовательский слой:
class CenterLossLayer(Layer):
def __init__(self, alpha=0.5, **kwargs):
self.alpha = alpha
super(CenterLossLayer, self).__init__(**kwargs)
def build(self, input_shape):
print('Center loss input 1 (feature_size): ', input_shape[0][1])
print('Center loss input 2 (num_classes): ', input_shape[1][1])
self.class_centers = self.add_weight(name='class_centers',
shape=(input_shape[1][1], input_shape[0][1]),
initializer='uniform',
trainable=False)
super(CenterLossLayer, self).build(input_shape)
def call(self, x, mask=None):
embeddings, one_hots = x
tf.print(self.class_centers, summarize=-1, output_stream='file:///tensors.txt')
batch_centers = K.dot(one_hots, self.class_centers)
batch_delta = batch_centers - embeddings
class_delta = K.dot(K.transpose(one_hots), batch_delta)
counts = K.sum(K.transpose(one_hots), axis=1, keepdims=True) + 1
class_delta = class_delta / counts
class_delta = K.in_train_phase(self.alpha * class_delta, 0 * class_delta)
updated_class_centers = self.class_centers - class_delta
self.add_update((self.class_centers, updated_class_centers), x[0])
losses = K.sum(K.square(embeddings - batch_centers), axis=1, keepdims=True)
return losses
def compute_output_shape(self, input_shape):
return (input_shape[1][0], )
и окончательная потеря:
def batch_mean_loss(y_true, y_pred):
return K.mean(y_pred, axis=0)
, где y_pred
равно losses
от CenterLossLayer.
Странная вещь в том, что даже если центры не обновляются, потеря центра уменьшается с каждой эпохой, и окончательная модель лучше, чем та, которая тренировалась только с потерей Softmax.