Я собрал VAE, используя плотные нейронные сети в Керасе.Во время model.fit
я получаю несоответствие размеров, но не уверен, что выбрасывает код.Ниже показано, как выглядит мой код
from keras.layers import Lambda, Input, Dense
from keras.models import Model
from keras.datasets import mnist
from keras.losses import mse, binary_crossentropy
from keras.utils import plot_model
from keras import backend as K
import keras
import numpy as np
import matplotlib.pyplot as plt
import argparse
import os
(x_train, y_train), (x_test, y_test) = mnist.load_data()
image_size = x_train.shape[1]
original_dim = image_size * image_size
x_train = np.reshape(x_train, [-1, original_dim])
x_test = np.reshape(x_test, [-1, original_dim])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
# network parameters
input_shape = (original_dim, )
intermediate_dim = 512
batch_size = 128
latent_dim = 2
epochs = 50
x = Input(batch_shape=(batch_size, original_dim))
h = Dense(intermediate_dim, activation='relu')(x)
z_mean = Dense(latent_dim)(h)
z_log_sigma = Dense(latent_dim)(h)
def sampling(args):
z_mean, z_log_sigma = args
#epsilon = K.random_normal(shape=(batch, dim))
epsilon = K.random_normal(shape=(batch_size, latent_dim))
return z_mean + K.exp(z_log_sigma) * epsilon
# note that "output_shape" isn't necessary with the TensorFlow backend
# so you could write `Lambda(sampling)([z_mean, z_log_sigma])`
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_sigma])
decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)
print('X Decoded Mean shape: ', x_decoded_mean.shape)
# end-to-end autoencoder
vae = Model(x, x_decoded_mean)
# encoder, from inputs to latent space
encoder = Model(x, z_mean)
# generator, from latent space to reconstructed inputs
decoder_input = Input(shape=(latent_dim,))
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)
def vae_loss(x, x_decoded_mean):
xent_loss = keras.metrics.binary_crossentropy(x, x_decoded_mean)
kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
return xent_loss + kl_loss
vae.compile(optimizer='rmsprop', loss=vae_loss)
print('X train shape: ', x_train.shape)
print('X test shape: ', x_test.shape)
vae.fit(x_train, x_train,
shuffle=True,
epochs=epochs,
batch_size=batch_size,
validation_data=(x_test, x_test))
Вот трассировка стека, которую я вижу, когда вызывается model.fit
.
File "/home/asattar/workspace/projects/keras-examples/blogautoencoder/VariationalAutoEncoder.py", line 81, in <module>
validation_data=(x_test, x_test))
File "/usr/local/lib/python2.7/dist-packages/Keras-2.2.4-py2.7.egg/keras/engine/training.py", line 1047, in fit
validation_steps=validation_steps)
File "/usr/local/lib/python2.7/dist-packages/Keras-2.2.4-py2.7.egg/keras/engine/training_arrays.py", line 195, in fit_loop
outs = fit_function(ins_batch)
File "/usr/local/lib/python2.7/dist-packages/Keras-2.2.4-py2.7.egg/keras/backend/tensorflow_backend.py", line 2897, in __call__
return self._call(inputs)
File "/usr/local/lib/python2.7/dist-packages/Keras-2.2.4-py2.7.egg/keras/backend/tensorflow_backend.py", line 2855, in _call
fetched = self._callable_fn(*array_vals)
File "/home/asattar/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1439, in __call__
run_metadata_ptr)
File "/home/asattar/.local/lib/python2.7/site-packages/tensorflow/python/framework/errors_impl.py", line 528, in __exit__
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [128,784] vs. [96,784]
[[{{node training/RMSprop/gradients/loss/dense_5_loss/logistic_loss/mul_grad/BroadcastGradientArgs}} = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@train...ad/Reshape"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](training/RMSprop/gradients/loss/dense_5_loss/logistic_loss/mul_grad/Shape, training/RMSprop/gradients/loss/dense_5_loss/logistic_loss/mul_grad/Shape_1)]]
Обратите внимание на "Несовместимые фигуры: [128,784] против [96,784]" в трассе стека "в конце трассы.