Я на самом деле довольно скептически отношусь к своей модели, потому что дискриминатор построен в зависимости от генератора, и он в основном идентичен gan, что заставляет меня думать, что во время обучения генератор и дискриминатор на самом деле не конкурируют сдруг друга состязательно. А также, если вход и выход идентичны для дискриминатора и гана, что означает, что во время вывода я мог бы использовать дискриминатор или ган для восстановления нового изображения, это также кажется подозрительным. Поэтому я попытался построить генератор и дискриминатор (2 входа и 2 выхода) как две независимые сети и связать их вместе, чтобы построить ган. Но, к сожалению, когда я тренирую gan, появляется сообщение об ошибке, что мне нужно передать значения в первый слой дискриминатора.
import keras
from keras import backend as K
from keras.layers import ReLU, LeakyReLU, Conv2D, Conv2DTranspose, BatchNormalization, concatenate, Flatten, Dense, Reshape
from keras.models import Model, clone_model, load_model
import numpy as np
K.clear_session()
# Build autoencoder to be the generator
img_shape = (152, 232, 1)
latent_dim = 16
inputs = keras.Input(shape=img_shape)
x = Conv2D(16, 3, padding='same', strides=(2,2), activation='relu')(inputs)
x = BatchNormalization()(x)
x = Conv2D(32, 3, padding='same', strides=(2,2), activation='relu')(x)
x = BatchNormalization()(x)
shape = K.int_shape(x)
x = Flatten()(x)
latent = Dense(latent_dim, name='latent_vector')(x)
x = Dense(shape[1] * shape[2] * shape[3])(latent)
x = Reshape((shape[1], shape[2], shape[3]))(x)
x = Conv2DTranspose(32, 3, padding='same')(x)
x = LeakyReLU()(x)
x = BatchNormalization()(x)
x = Conv2DTranspose(16, 3, padding='same', strides=(2,2))(x)
x = LeakyReLU()(x)
x = BatchNormalization()(x)
outputs = Conv2DTranspose(1, 3, padding='same', activation='tanh', strides=(2,2))(x)
generator = Model(inputs, outputs)
generator.summary()
ae_disc = clone_model(generator)
ae_disc.name="autoencoder_discriminator"
inputs_1 = keras.Input(shape=img_shape)
inputs_2 = keras.Input(shape=img_shape)
dis_outputs_1 = ae_disc(inputs_1)
dis_outputs_2 = ae_disc(inputs_2)
# Build discriminator
discriminator = Model([inputs_1, inputs_2], [dis_outputs_1, dis_outputs_2])
# Define loss function for discriminator
loss_d = K.sum(K.abs(inputs_1 - dis_outputs_1)) - K.sum(K.abs(inputs_2 - dis_outputs_2))
discriminator.add_loss(loss_d)
# Compile discriminator
discriminator_optimizer = keras.optimizers.RMSprop(lr=0.0008, clipvalue=1.0, decay=1e-8)
discriminator.compile(optimizer=discriminator_optimizer)
discriminator.summary()
# Freeze discriminator
discriminator.trainable = False
gan_inputs = keras.Input(shape=img_shape)
dis_input_1 = keras.activations.linear(gan_inputs)
dis_input_2 = generator(gan_inputs)
[gan_outputs_1, gan_outputs_2] = discriminator([dis_input_1, dis_input_2])
# Build gan
gan = Model(gan_inputs, [gan_outputs_1, gan_outputs_2])
# Define gan loss
loss_g = K.sum(K.abs(gan_inputs - dis_input_2)) + K.sum(K.abs(dis_input_2 - gan_outputs_2))
gan.add_loss(loss_g)
# Compile gan
gan_optimizer = keras.optimizers.RMSprop(lr=0.0008, clipvalue=1.0, decay=1e-8)
gan.compile(optimizer=gan_optimizer)
gan.summary()
# Train model
# Squeeze pixel values into [-1, 1] since I use 'tanh' as activation for the autoencoder output
x_train = train_imgs.astype('float32') / 255.*2-1
batch_size = 20
start = 0
for step in range(1000):
stop = start + batch_size
images = x_train[start: stop]
generated_images = generator.predict(images)
d_loss = discriminator.train_on_batch([images, generated_images], None)
g_loss = gan.train_on_batch(images, None)
start += batch_size
if start > len(x_train) - batch_size:
start = 0
# Print losses
if step % 10 == 0:
# Print metrics
print('discriminator loss at step %s: %s' % (step, d_loss))
print('generator loss at step %s: %s' % (step, g_loss))
Это сообщение об ошибке, которое я получил: InvalidArgumentError: Вы должны кормитьзначение для тензора заполнителя 'input_2' с плавающей точкой dtype и shape [?, 152,232,1] [[{{node input_2}} = Placeholderdtype = DT_FLOAT, shape = [?, 152,232,1], _device = "/ job: localhost/ replica: 0 / task: 0 / device: CPU: 0 "]]
Кто-нибудь знает, как это исправить? Большое спасибо заранее !!