GAN: How to update the batch normalization parameters in TensorFlow
0 votes
/ 09 May 2019

Batch normalization update

  • I am learning GANs (Generative Adversarial Networks) and I use batch normalization in TensorFlow.
  • How do I update the batch normalization parameters (moving mean, moving variance) in TensorFlow?
  • How do I update the graph after using batch_normalization in both the discriminator and the generator?
  • Or are these parameters updated automatically just by calling batch_normalization in the discriminator and generator? (A minimal sketch of the documented update pattern follows below.)
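
For reference, this is the pattern that the tf.layers.batch_normalization documentation describes: the moving mean and moving variance are refreshed only by the ops collected in tf.GraphKeys.UPDATE_OPS, so those ops have to run together with the train op. A minimal standalone sketch of my own (the input size, dense layer and dummy loss are only placeholders, not from my project):

# Minimal sketch of the documented UPDATE_OPS pattern (placeholder sizes, dummy loss).
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 784])
is_training = tf.placeholder(tf.bool)

h = tf.layers.dense(x, 128)
h = tf.layers.batch_normalization(h, training=is_training)
h = tf.nn.relu(h)
loss = tf.reduce_mean(tf.square(h))  # dummy loss, just for illustration

# The moving mean / variance are only updated by these ops, not by minimize(),
# so attach them to the train op via a control dependency.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)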

Full source code: Link

Development environment

  • IDE: PyCharm
  • OS platform and distribution: Windows 10 x64
  • TensorFlow installed from: Anaconda
  • TensorFlow version: 1.12.0
  • Python version: 3.6.7
  • Mobile device: N/A
  • Exact command to reproduce: N/A
  • GPU model and memory: NVIDIA GeForce GTX 1080 Ti
  • CUDA / cuDNN: 9.0 / 7.4

Discriminator

# discriminator
# (num_input, num_hidden, num_output and myinit are defined elsewhere in the full source)
import tensorflow as tf

def discriminator(x, train_state, reuse=False):
    with tf.variable_scope(name_or_scope="Dis", reuse=reuse) as scope:
        dw1 = tf.get_variable(name="w1", shape=[num_input, num_hidden], initializer=myinit)
        db1 = tf.get_variable(name="b1", shape=[num_hidden], initializer=myinit)
        dw2 = tf.get_variable(name="w2", shape=[num_hidden, num_output], initializer=myinit)
        db2 = tf.get_variable(name="b2", shape=[num_output], initializer=myinit)

        # Build the batch-norm layers inside the "Dis" scope with explicit names,
        # so reuse=True shares their variables and their update ops live under "Dis/".
        fcHidden = tf.matmul(x, dw1) + db1
        bnHidden = tf.layers.batch_normalization(fcHidden, training=train_state, name="bn1")
        # hidden = tf.nn.leaky_relu(bnHidden)
        hidden = tf.nn.relu(bnHidden)
        logits = tf.matmul(hidden, dw2) + db2
        bnLogits = tf.layers.batch_normalization(logits, training=train_state, name="bn2")
        output = tf.nn.sigmoid(bnLogits)
    return output, logits
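
As a sanity check (my own snippet, not part of the linked source), the variables created under the "Dis" scope can be listed after discriminator() has been called; the moving_mean / moving_variance of the two batch-norm layers should appear there, and because they are non-trainable they are not returned by tf.trainable_variables():

# List everything created under the "Dis" scope (run inside the graph context,
# after discriminator() has been called). The batch-norm moving statistics are
# non-trainable GLOBAL_VARIABLES, so minimize() never touches them.
for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="Dis"):
    print(v.name, v.shape)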

Generator

# generator
def generator(z, train_state):
    with tf.variable_scope(name_or_scope="Gen") as scope:
        gw1 = tf.get_variable(name="w1", shape=[num_noise, num_hidden], initializer=myinit)
        gb1 = tf.get_variable(name="b1", shape=[num_hidden], initializer=myinit)
        gw2 = tf.get_variable(name="w2", shape=[num_hidden, num_input], initializer=myinit)
        gb2 = tf.get_variable(name="b2", shape=[num_input], initializer=myinit)

        # Batch-norm layers inside the "Gen" scope, so their gamma/beta fall under
        # "Gen/" (and are picked up by the g_vars filter below) and their update
        # ops can be collected per scope.
        fcHidden = tf.matmul(z, gw1) + gb1
        bnHidden = tf.layers.batch_normalization(fcHidden, training=train_state, name="bn1")
        # hidden = tf.nn.leaky_relu(bnHidden)
        hidden = tf.nn.relu(bnHidden)
        logits = tf.matmul(hidden, gw2) + gb2
        bnLogits = tf.layers.batch_normalization(logits, training=train_state, name="bn2")
        output = tf.nn.sigmoid(bnLogits)
    return output, logits, hidden, tf.nn.leaky_relu(fcHidden)
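
The update ops can be inspected per scope in the same way (again my own snippet). With the layers built inside the "Gen" and "Dis" scopes, tf.get_collection can filter UPDATE_OPS by scope, which is what would let each network keep its own moving-average updates:

# Inspect which moving-average update ops each network contributed
# (run inside the graph context, after the graph below is built).
for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope="Gen"):
    print("Gen update op:", op.name)
for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope="Dis"):
    print("Dis update op:", op.name)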

Graph

g = tf.Graph()
with g.as_default():
    X = tf.placeholder(tf.float32, [None, num_input]) # GAN is unsupervised learning, so no y (label) is used.
    Z = tf.placeholder(tf.float32, [None, num_noise]) # Z is the noise fed to the generator as input.
    preLabel = tf.placeholder(tf.float32, [None, 1])
    trainingState = tf.placeholder(tf.bool)

    # Pre-train
    result_of_pre, logits_pre = discriminator_pre(X)
    p_loss = tf.reduce_mean(tf.square(result_of_pre - preLabel))

    # Discriminator & Generator
    fake_x, fake_logits, ghidden, gohidden  = generator(Z, trainingState)
    result_of_real, logits_real = discriminator(X, trainingState)
    result_of_fake, logits_fake = discriminator(fake_x, trainingState, True)

    # Define the discriminator / generator loss functions.
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_real, labels=tf.ones_like(result_of_real)))  # -log(D(x))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=tf.zeros_like(result_of_fake)))  # -log(1 - D(G(z)))
    d_loss = d_loss_real + d_loss_fake  # -[log(D(x)) + log(1 - D(G(z)))]
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=tf.ones_like(result_of_fake)))  # -log(D(G(z)))

    # Parameter
    t_vars = tf.trainable_variables() # return list
    g_vars = [var for var in t_vars if "Gen" in var.name]
    d_vars = [var for var in t_vars if "Dis" in var.name]
    p_vars = [var for var in t_vars if "Pre" in var.name]

    # Optimizer / Gradient
    p_train = tf.train.AdamOptimizer(learning_rate=learningrate_dis, beta1=0.5, beta2=0.999).minimize(p_loss, var_list=p_vars)
    g_train = tf.train.AdamOptimizer(learning_rate=learningrate_gen, beta1=0.5, beta2=0.999).minimize(g_loss, var_list=g_vars)
    d_train = tf.train.AdamOptimizer(learning_rate=learningrate_dis, beta1=0.5, beta2=0.999).minimize(d_loss, var_list=d_vars)

    # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # with tf.control_dependencies(update_ops):
    #     g_train = tf.train.AdamOptimizer(learning_rate=learningrate_gen, beta1=0.5, beta2=0.999).minimize(g_loss, var_list=g_vars)
    #     d_train = tf.train.AdamOptimizer(learning_rate=learningrate_dis, beta1=0.5, beta2=0.999).minimize(d_loss, var_list=d_vars)

    # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # p_train = tf.train.AdamOptimizer(learning_rate=learningrate_dis, beta1=0.5, beta2=0.999).minimize(p_loss, var_list=p_vars)
    # p_train = tf.group([p_train, update_ops])

Batch normalization update

# Optimizer / Gradient
    # Method 1
    p_train = tf.train.AdamOptimizer(learning_rate=learningrate_dis, beta1=0.5, beta2=0.999).minimize(p_loss, var_list=p_vars)
    g_train = tf.train.AdamOptimizer(learning_rate=learningrate_gen, beta1=0.5, beta2=0.999).minimize(g_loss, var_list=g_vars)
    d_train = tf.train.AdamOptimizer(learning_rate=learningrate_dis, beta1=0.5, beta2=0.999).minimize(d_loss, var_list=d_vars)

    # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Method 2
    # with tf.control_dependencies(update_ops):
    #     g_train = tf.train.AdamOptimizer(learning_rate=learningrate_gen, beta1=0.5, beta2=0.999).minimize(g_loss, var_list=g_vars)
    #     d_train = tf.train.AdamOptimizer(learning_rate=learningrate_dis, beta1=0.5, beta2=0.999).minimize(d_loss, var_list=d_vars)

    # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Method 3
    # p_train = tf.train.AdamOptimizer(learning_rate=learningrate_dis, beta1=0.5, beta2=0.999).minimize(p_loss, var_list=p_vars)
    # p_train = tf.group([p_train, update_ops])
...
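
For comparison, this is how I understand Method 2 would look with the update ops split per scope, so that a discriminator step only refreshes the discriminator's moving statistics and a generator step only refreshes the generator's (my own sketch, not verified to be the recommended way):

    # Sketch of Method 2 with per-scope update ops (my assumption, not confirmed):
    # each minimize() gets a control dependency only on the UPDATE_OPS that were
    # created inside its own variable scope.
    d_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope="Dis")
    g_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope="Gen")

    with tf.control_dependencies(d_update_ops):
        d_train = tf.train.AdamOptimizer(learning_rate=learningrate_dis,
                                         beta1=0.5, beta2=0.999).minimize(d_loss, var_list=d_vars)
    with tf.control_dependencies(g_update_ops):
        g_train = tf.train.AdamOptimizer(learning_rate=learningrate_gen,
                                         beta1=0.5, beta2=0.999).minimize(g_loss, var_list=g_vars)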