Резюме : Я пытаюсь переучить простой CNN для MNIST без использования API высокого уровня.Я уже преуспел в этом, переобучив всю сеть, но моя текущая цель - переобучить только последние один или два полностью подключенных слоя.
Работа на данный момент: Допустим, у меня есть CNN со следующей структурой
- Сверточный слой
- RELU
- Объединение в пулСлой
- Сверточный слой
- RELU
- Объединяющий слой
- Полностью связанный слой
- RELU
- Выпадающий слой
- Полностью подключенный слой к 10 выходным классам
Моя цель - переобучить либо последний полностью подключенный слой, либо последние два полностью подключенных слоя.
Пример сверточного слоя:
W_conv1 = tf.get_variable("W", [5, 5, 1, 32],
initializer=tf.truncated_normal_initializer(stddev=np.sqrt(2.0 / 784)))
b_conv1 = tf.get_variable("b", initializer=tf.constant(0.1, shape=[32]))
z = tf.nn.conv2d(x_image, W_conv1, strides=[1, 1, 1, 1], padding='SAME')
z += b_conv1
h_conv1 = tf.nn.relu(z + b_conv1)
Пример полностью подключенного слоя:
input_size = 7 * 7 * 64
W_fc1 = tf.get_variable("W", [input_size, 1024], initializer=tf.truncated_normal_initializer(stddev=np.sqrt(2.0/input_size)))
b_fc1 = tf.get_variable("b", initializer=tf.constant(0.1, shape=[1024]))
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
Мое предположение : при выполнениипри обратном распространении нового набора данных я просто проверяю, что мои веса W и b (из W * x + b) зафиксированы в не полностью связанных слоях.
Первая мысль о том, как это сделать : сохраните W и b, выполните шаг обратного распространения и замените новый W и b старым в слоях, которые я не делаю.хочу поменять.
Мои мысли об этом первом подходе :
- Это требует больших вычислительных ресурсов и тратит память.Единственное преимущество выполнения только последнего слоя состоит в том, что нет необходимости выполнять другие
- Обратное распространение может работать по-разному, если не применяется ко всем слоям?
Myвопрос :
- Как правильно переобучить отдельные слои в нейронной сети, когда не используются API высокого уровня.Приветствуются как концептуальные, так и кодовые ответы.
PS Полностью осознавая, как это можно сделать с помощью API высокого уровня.Пример: https://towardsdatascience.com/how-to-train-your-model-dramatically-faster-9ad063f0f718. Просто не хочу, чтобы нейронные сети были волшебными, я хочу знать, что на самом деле происходит