Я пытаюсь реализовать VAE в keras
самостоятельно, с небольшим руководством из официального учебного руководства keras VAE код здесь - см. Код ниже.
Я получил nan
вмоя потеря, поэтому я попытался выяснить, что вызвало это. Похоже, что потеря kl является причиной этого, потому что когда я запускаю только с потерей kl, я получаю ValueError
высказывание An operation has None for gradient.
Код из учебника keras сработал для меня, и я не могу найти то, чторазница между моим кодом и учебным.
Пожалуйста, не комментируйте, что VAE будет работать хорошо или нет, я просто хочу, чтобы он сначала запустился, а затем отрегулировал его.
from keras.layers import Input, Dense, Flatten, Reshape, Lambda
from keras.models import Model
from keras.losses import mse
from keras.datasets import mnist
import keras.backend as K
import numpy as np
(x_train,y_train),(x_test,y_test) = mnist.load_data()
def sampling(args):
mean, log_var = args
batch = K.shape(mean)[0]
dim = K.int_shape(mean)[1]
epsilon = K.random_normal(shape=(batch,dim))
return mean + K.exp(0.5*log_var)*epsilon
latent_dim = 2
#Encoder
enc_inp = Input(shape=(28,28),name='encoder_input')
x_e = Flatten()(enc_inp)
x_e = Dense(16,activation='relu')(x_e)
x_e = Dense(16,activation='relu')(x_e)
z_mean = Dense(latent_dim,activation='linear')(x_e)
z_log_var = Dense(latent_dim,activation='linear')(x_e)
z = Lambda(sampling)([z_mean,z_log_var])
myEnc = Model(enc_inp,outputs=[z_mean,z_log_var,z])
myEnc.summary()
#Decoder
dec_inp = Input(shape=(latent_dim,))
x_d = Dense(16,activation='relu')(dec_inp)
x_d = Dense(16,activation='relu')(x_d)
x_d = Dense(784,activation='sigmoid')(x_d)
x_d = Reshape((28,28))(x_d)
myDec = Model(dec_inp,x_d)
myDec.summary()
#VAE
outputs = myDec(myEnc(enc_inp)[2])
vae = Model(enc_inp,outputs)
vae.summary()
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
#Only kl_loss for testing as it causes problems
vae.add_loss(kl_loss)
vae.compile(optimizer='adam')
vae.fit(x_train,epochs=10,batch_size=32)
ValueError: An operation has `None` for gradient.
Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.