Генеративная ошибка конвергенции состязательной сети - PullRequest
0 голосов
/ 04 августа 2020

Я новичок в области глубокого обучения. Я пытаюсь разработать «Генеративную состязательную сеть», следуя примеру, приведенному на веб-сайте TensorFlow. К сожалению, я получаю ошибки и не могу запустить его. Я публикую здесь код, до которого он может быть выполнен, а также сообщаю об ошибке, которую необходимо устранить.

import numpy as np 
import matplotlib.pyplot as plt 
import pandas as pd
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.layers import Input, Dense, Reshape, Flatten, Dropout 
from keras.layers import BatchNormalization, Activation
from keras.layers.advanced_activations import LeakyReLU 
from keras.layers.convolutional import Conv2D 
from keras.models import Sequential, Model 
from keras.optimizers import SGD 

def makediag3d(a):
    a = np.asarray(a)
    depth, size = a.shape
    x = np.zeros((depth,size,size))
    for i in range(depth):
        x[i].flat[slice(0,None,1+size)] = a[i]
   return x

input = pd.read_excel( '3_height.xls')
new_input = (makediag3d(input))
output = pd.read_excel( '3_absor.xls')
new_output= np.array(output)

BUFFER_SIZE = 50000 BATCH_SIZE = 128 train_dataset =

tf.data .Dataset.from_tensor_slices (new_input) .shuffle (BUFFER_SIZE) .batch (BATCH_SIZE)

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(3*3*256, use_bias=False, input_shape=(50,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((3, 3, 256)))
    assert model.output_shape == (None, 7, 7, 256)
    model.add(layers.Conv2DTranspose(128, (5, 5), strides(1,1),padding='same',  use_bias=False)) 
    assert model.output_shape == (None, 3, 3, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5),strides(2,2),padding='same', 
    use_bias=False))
    assert model.output_shape == (None, 3, 3, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2),padding='same',  

    use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 3, 3, 1)

    return model
generator = make_generator_model()

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                 input_shape=[3, 3, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.2))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.2))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model
discriminator = make_discriminator_model()

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                             discriminator_optimizer=discriminator_optimizer,
                             generator=generator,
                             discriminator=discriminator)

Ошибка, которую я получаю:

NameError                                 Traceback (most recent call last)
<ipython-input-84-0106cda2248f> in <module>
      3 checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
      4                                  discriminator_optimizer=discriminator_optimizer,
----> 5                                  generator=generator,
      6                                  discriminator=discriminator)
      7 

NameError: name 'generator' is not defined

Я также предоставляю ссылку на весь код для ознакомления.

https://drive.google.com/file/d/1VJOARa9XvPWtcjQFbd0jATLsbkeEzJcl/view? usp = обмен

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...