Ошибка значения с формой дискриминатора в tflearn DCGAN - PullRequest
0 голосов
/ 27 мая 2019

Я пытаюсь создать DCGAN для использования с настраиваемым набором данных с использованием инфраструктуры tflearn.

В настоящее время я считаю, что проблема заключается в коде дискриминатора, но я не уверен.Если вы измените форму во втором слое conv_2d, он изменит первое число в сообщении об ошибке.

def discriminator(x, reuse=False):
    with tf.variable_scope('Discriminator', reuse=reuse):
        x = tflearn.conv_2d(x, 64, 5, activation='tanh')
        x = tflearn.avg_pool_2d(x, 2)
        x = tflearn.conv_2d(x, 6272, 5, activation='tanh')
        x = tflearn.avg_pool_2d(x, 2)
        x = tflearn.fully_connected(x, 1024, activation='tanh')
        x = tflearn.fully_connected(x, 2)
        x = tf.nn.softmax(x)
        return x


# Input Data
gen_input = tflearn.input_data(shape=[None, z_dim], name='input_gen_noise')
input_disc_noise = tflearn.input_data(shape=[None, z_dim], name='input_disc_noise')
input_disc_real = tflearn.input_data(shape=[None, 400, 400, 1], name='input_disc_real')

# Build Discriminator
disc_fake = discriminator(generator(input_disc_noise))
disc_real = discriminator(input_disc_real, reuse=True)
disc_net = tf.concat([disc_fake, disc_real], axis=0)
# Build Stacked Generator/Discriminator
gen_net = generator(gen_input, reuse=True)
stacked_gan_net = discriminator(gen_net, reuse=True)

Ошибка с кодом ошибки:

Traceback (most recent call last):
  File "dcgan.py", line 83, in <module>
    disc_real = discriminator(input_disc_real, reuse=True)
  File "dcgan.py", line 70, in discriminator
    x = tflearn.fully_connected(x, 1024, activation='tanh')
  File "D:\python\lib\site-packages\tflearn\layers\core.py", line 157, in fully_connected
    restore=restore)
  File "D:\python\lib\site-packages\tensorflow\contrib\framework\python\ops\arg_scope.py", line 182, in func_with_args
    return func(*args, **current_args)
  File "D:\python\lib\site-packages\tflearn\variables.py", line 65, in variable
    validate_shape=validate_shape)
  File "D:\python\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1479, in get_variable
    aggregation=aggregation)
  File "D:\python\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1220, in get_variable
    aggregation=aggregation)
  File "D:\python\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 547, in get_variable
    aggregation=aggregation)
  File "D:\python\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 499, in _true_getter
    aggregation=aggregation)
  File "D:\python\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 853, in _get_single_variable
    found_var.get_shape()))
ValueError: Trying to share variable Discriminator/FullyConnected/W, but specified shape (62720000, 1024) and found shape (307328, 1024).
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...