Безусловные порождающие противозаконные сети в наборе данных MNIST - PullRequest
1 голос
/ 23 декабря 2019

Я обучаю безусловные GAN на наборе данных MNIST с использованием библиотеки tfgan и оценок tfgan. Все работает нормально, и изображения генерируются, см. . Вспомогательные функции для функций модели генератора и дискриминатора пишутся с использованием tf.layers. Но когда я меняю только вспомогательные функции и пишу их, используя tf.keras, тот же самый точный код не работает, и изображения не генерируются, см. . Кто-нибудь может мне помочь с этим? Единственное различие между этими двумя сценариями заключается в изменении вспомогательных функций с использования tf.layers на использование tf.keras. Вспомогательные функции с использованием tf.layers:

def _dense(inputs, units, l2_weight):
  return tf.layers.dense(
      inputs, units, None,
      kernel_initializer=tf.keras.initializers.glorot_uniform,
      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),
      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight))

def _batch_norm(inputs, is_training):
  return tf.layers.batch_normalization(
      inputs, momentum=0.999, epsilon=0.001, training=is_training)

def _deconv2d(inputs, filters, kernel_size, stride, l2_weight):
  return tf.layers.conv2d_transpose(
      inputs, filters, [kernel_size, kernel_size], strides=[stride, stride], 
      activation=tf.nn.relu, padding='same',
      kernel_initializer=tf.keras.initializers.glorot_uniform,
      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),
      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight))

def _conv2d(inputs, filters, kernel_size, stride, l2_weight):
  return tf.layers.conv2d(
      inputs, filters, [kernel_size, kernel_size], strides=[stride, stride], 
      activation=None, padding='same',
      kernel_initializer=tf.keras.initializers.glorot_uniform,
      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),
      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight)) 

Вспомогательные функции с использованием tf.keras:

def _dense(inputs, units, l2_weight):
  return Dense(units,
      kernel_initializer=tf.keras.initializers.glorot_uniform,
      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),
      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight))(inputs)

def _batch_norm(inputs, is_training):
  return BatchNormalization(momentum=0.999, epsilon=0.001)(inputs, training = is_training)


def _deconv2d(inputs, filters, kernel_size, stride, l2_weight):
  return Conv2DTranspose(filters=filters, kernel_size=[kernel_size, kernel_size], strides=[stride, stride],
                                      activation=keras.activations.relu, padding='same',
                                      kernel_initializer=keras.initializers.glorot_uniform,
                                      kernel_regularizer=keras.regularizers.l2(l=l2_weight),
                                      bias_regularizer=keras.regularizers.l2(l=l2_weight))(inputs)

def _conv2d(inputs, filters, kernel_size, stride, l2_weight):
  return Conv2D(filters=filters, kernel_size=[kernel_size, kernel_size], strides=[stride, stride], padding='same',
      kernel_initializer=tf.keras.initializers.glorot_uniform,
      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),
      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight))(inputs)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...