Сиамская сеть в наборе данных MNIST не проходит обучение - PullRequest
0 голосов
/ 21 апреля 2020

Я обучаю сиамскую сеть с конструктивной потерей на двух классах набора данных MNIST, чтобы определить, являются ли два изображения похожими или нет. Хотя потери вначале уменьшаются, позже они замирают с точностью около 0,5.

Модель обучена на парах изображений и метки (0,0 для разных, 1,0 для одинаковых). Я использовал только два класса для простоты (нули и единицы) и подготовил набор данных, чтобы он содержал каждую пару изображений. Я проверил, что набор данных согласован ( пары изображений из набора данных ). Я также безуспешно экспериментировал с нормализацией данных, различными размерами пакетов, скоростями обучения, инициализацией и константами регуляризации.

Это модель:

class Encoder(Model):
    """
    A network that finds a 50-dimensional representation of the input images
    so that the distances between them minimize the constructive loss
    """

    def __init__(self):
        super(Encoder, self).__init__(name='encoder')

        self.cv = Conv2D(32, (3, 3), activation='relu', padding='Same',
                         input_shape=(28, 28, 1),
                         kernel_regularizer=tf.keras.regularizers.l2(0.01))
        self.pool = MaxPooling2D((2, 2))
        self.flatten = Flatten()
        self.dense = Dense(50, activation=None,
                           kernel_regularizer=tf.keras.regularizers.l2(0.01))

    def call(self, inputs, training=None, mask=None):
        """ Forward pass for one image """
        x = self.cv(inputs)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.dense(x)
        return x

    @staticmethod
    def distance(difference):
        """ The D function from the paper which is used in loss """
        distance = tf.sqrt(tf.reduce_sum(tf.pow(difference, 2), 0))
        return distance

Потеря и точность:

def simnet_loss(target, x1, x2):
    difference = x1 - x2
    distance_vector = tf.map_fn(lambda x: Encoder.distance(x), difference)
    loss = tf.map_fn(lambda distance: target * tf.square(distance) +
                                      (1.0 - target) * tf.square(tf.maximum(0.0, 1.0 - distance)), distance_vector)
    average_loss = tf.reduce_mean(loss)
    return average_loss

def accuracy(y_true, y_pred):
    distance_vector = tf.map_fn(lambda x: Encoder.distance(x), y_pred)
    accuracy = tf.keras.metrics.binary_accuracy(y_true, distance_vector)
    return accuracy

Обучение:

def train_step(images, labels):
    with tf.GradientTape() as tape:
        x1, x2 = images[:, 0, :, :, :], images[:, 1, :, :, :]
        x1 = model(x1)
        x2 = model(x2)
        loss = simnet_loss(labels, x1, x2)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return loss

model = Encoder()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

for epoch in range(n_epoch):
    epoch_loss = 0
    n_batches = int(x_train.shape[0]/batch_size)
    for indices in np.array_split(np.arange(x_train.shape[0]), indices_or_sections=n_batches):
        x = np.take(x_train, indices, axis=0)
        y = np.take(y_train, indices, axis=0)
        epoch_loss += train_step(x, y)

    epoch_loss = epoch_loss / n_batches
    accuracy = test_step(x_train, y_train)
    val_accuracy = test_step(x_test, y_test)
    tf.print("epoch:", epoch, "loss:", epoch_loss, "accuracy:", accuracy,
             "val_accuracy:", val_accuracy, output_stream=sys.stdout)

Приведенный выше код дает:

эпоха: 0 потеря: 0,755419433 Точность: 0,318898171 val_accuracy: 0.310316473

эпоха: 1 потеря: 0.270610392 точность: 0.369466901 val_accuracy: 0.360871345

e: точка: 0,262594223 точность: 0,430587918 val_accuracy: 0.418002456

эпоха: 3 потери: 0,258690506 точность: 0,428258181 val_accuracy: 0,427044809

эпоха: 4 потери: 0,25654456 точность: 0,448 * * * 10 * 10 * 10 * * 10 * * 10 * * 10 * * 10 * * 10 * 10 * * 10 * * 10 * 10 * * 10 * * * 10 * 10 * * * * * 10 * 10 * * * * 10 * * 10 * * * * * * 10 * 10 * * : 5 потеря: 0.255373538 точность: 0.444840342 val_accuracy: 0.454993844

эпоха: 6 потеря: 0.254594624 точность: 0.453885168 val_accuracy: 0.454171807

...