Модель не тренировалась при использовании tf.Session () - PullRequest
0 голосов
/ 08 апреля 2020

Я новичок в TensorFlow и Keras. Я пытаюсь понять GAN, используя TF 1.x (используя этот репозиторий https://github.com/hse-aml), и у меня возникли проблемы с функцией ниже, которая использовалась для создания сеанса. Моя проблема в том, что именно делает эта функция (почему мы не можем использовать tf.Session () в одиночку). Когда я использую tf.Session (), модель не тренировалась.

from keras import backend as K

def weird_session():
    curr_session = tf.get_default_session()
    if curr_session is not None:
        curr_session.close()
    K.clear_session()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    s = tf.InteractiveSession(config=config)
    K.set_session(s)
    return s
s=weird_session()

Это полный код, который я использовал.

%tensorflow_version 1.x
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers as L
import numpy as np


IMG_SHAPE=(36,36,3)
####################################################################################### DOWNLOAD DATASET
!git clone https://github.com/RaviSoji/colab_utils.git  # Include the "!".
import colab_utils
drive = colab_utils.get_gdrive()
colab_utils.pull_from_gdrive(drive, 'GAN/my.npy','hah.npy')
dataset=np.load('hah.npy')
plt.imshow(dataset[0])

data = np.float32(dataset)/255.
########################################################################################

from keras import backend as K

####################################################  weird_session() function -the problem
def weird_session():
    curr_session = tf.get_default_session()
    if curr_session is not None:
        curr_session.close()
    K.clear_session()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    s = tf.InteractiveSession(config=config)
    K.set_session(s)
    return s
##################################################
s=weird_session()


IMG_SHAPE = data.shape[1:]

CODE_SIZE = 256

generator = Sequential()
generator.add(L.InputLayer([CODE_SIZE],name='noise'))
generator.add(L.Dense(10*8*8, activation='elu'))
generator.add(L.Reshape((8,8,10)))
generator.add(L.Conv2DTranspose(64,kernel_size=(5,5),activation='elu'))
generator.add(L.Conv2DTranspose(64,kernel_size=(5,5),activation='elu'))
generator.add(L.UpSampling2D(size=(2,2)))
generator.add(L.Conv2DTranspose(32,kernel_size=3,activation='elu'))
generator.add(L.Conv2DTranspose(32,kernel_size=3,activation='elu'))
generator.add(L.Conv2DTranspose(32,kernel_size=3,activation='elu'))
generator.add(L.Conv2D(3,kernel_size=3,activation=None))

discriminator = Sequential()
discriminator.add(L.InputLayer(IMG_SHAPE))
discriminator.add(L.Conv2D(filters=16, kernel_size=(3, 3), strides=1))
discriminator.add(L.LeakyReLU(0.12))
discriminator.add(L.Conv2D(filters=32, kernel_size=(3, 3), strides=1))
discriminator.add(L.LeakyReLU(0.12))
discriminator.add(L.MaxPool2D(pool_size=(2, 2)))
discriminator.add(L.Conv2D(filters=32, kernel_size=(3, 3), strides=1))
discriminator.add(L.LeakyReLU(0.12))
discriminator.add(L.Conv2D(filters=64, kernel_size=(3, 3), strides=1))
discriminator.add(L.LeakyReLU(0.12))
discriminator.add(L.MaxPool2D(pool_size=(2, 2))) 
discriminator.add(L.Flatten())
discriminator.add(L.Dense(256,activation='tanh'))
discriminator.add(L.Dense(2,activation=tf.nn.log_softmax))


noise = tf.placeholder('float32',[None,CODE_SIZE])
real_data = tf.placeholder('float32',[None,]+list(IMG_SHAPE))
logp_real = discriminator(real_data)
generated_data = generator(noise)
logp_gen = discriminator(generated_data)

d_loss = -tf.reduce_mean(logp_real[:,1] + logp_gen[:,0])
d_loss += tf.reduce_mean(discriminator.layers[-1].kernel**2)
disc_optimizer =  tf.train.GradientDescentOptimizer(1e-3).minimize(d_loss,var_list=discriminator.trainable_weights)

g_loss = -tf.reduce_mean(logp_gen[:,1])
gen_optimizer = tf.train.AdamOptimizer(1e-4).minimize(g_loss,var_list=generator.trainable_weights)

s.run(tf.global_variables_initializer())


def sample_noise_batch(bsize):
    return np.random.normal(size=(bsize, CODE_SIZE)).astype('float32')

def sample_data_batch(bsize):
    idxs = np.random.choice(np.arange(data.shape[0]), size=bsize)
    return data[idxs]

def sample_images(nrow,ncol, sharp=False):
    images = generator.predict(sample_noise_batch(bsize=nrow*ncol))
    if np.var(images)!=0:
        images = images.clip(np.min(data),np.max(data))
    for i in range(nrow*ncol):
        plt.subplot(nrow,ncol,i+1)
        if sharp:
            plt.imshow(images[i].reshape(IMG_SHAPE),cmap="gray", interpolation="none")
        else:
            plt.imshow(images[i].reshape(IMG_SHAPE),cmap="gray")
    plt.show()

def sample_probas(bsize):
    plt.title('Generated vs real data')
    plt.hist(np.exp(discriminator.predict(sample_data_batch(bsize)))[:,1],
             label='D(x)', alpha=0.5,range=[0,1])
    plt.hist(np.exp(discriminator.predict(generator.predict(sample_noise_batch(bsize))))[:,1],
             label='D(G(z))',alpha=0.5,range=[0,1])
    plt.legend(loc='best')
    plt.show()

from IPython import display

for epoch in range(50000):
    feed_dict = {
        real_data:sample_data_batch(100),
        noise:sample_noise_batch(100)
    }

    for i in range(5):
        s.run(disc_optimizer,feed_dict)
    s.run(gen_optimizer,feed_dict)

    if epoch %100==0:
        display.clear_output(wait=True)
        sample_images(2,3,True)

1 Ответ

0 голосов
/ 09 апреля 2020

Документация здесь

Единственная разница между Session и InteractiveSession состоит в том, что InteractiveSession делает себя сеансом по умолчанию, так что вы можете вызвать run() или eval() без явного вызова сеанса.

Это может быть полезно, если вы экспериментируете с TF в оболочке python или в записных книжках Jupyter, поскольку при этом не требуется передавать явный объект Session для запуска операций.

Так что, если вы используете просто tf.Session, вам также потребуется сделать его сеансом по умолчанию.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...