Я использую Keras для работы с КТ-изображениями и пытаюсь обучить трёхмерную (3D) модель WGAN-GP. Я взял за основу этот пример, переписал его и построил архитектуру на остаточных блоках (ResNet); обучение запускается успешно. Однако сгенерированные изображения получаются практически одинаковыми (см. результаты 1 и 2). Поскольку мне не удалось перевести программу оригинального автора в трёхмерный вариант, я попытался с помощью Keras воспроизвести структуру сети и гиперпараметры, близкие к архитектуре оригинального автора. У меня 41 обучающий образец формы (срезы, строки, столбцы, каналы) = (64, 128, 128, 1). Есть ли у кого-нибудь предложения по улучшению? Большое спасибо!
class RandomWeightedAverage(_Merge):
    """Provides a (random) weighted average between real and generated image samples.

    The layer takes two inputs ``[real, fake]`` and returns
    ``alpha * real + (1 - alpha) * fake``, where ``alpha`` is drawn
    uniformly at random with shape ``(batch_size, 1, 1, 1, 1)``: one
    coefficient per sample, broadcast over the remaining four axes.

    Args:
        batch_size: Number of samples per batch. It is baked into the shape
            of ``alpha``, so both inputs must carry exactly this many samples.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def _merge_function(self, inputs):
        """Blend ``inputs[0]`` and ``inputs[1]`` with a random per-sample weight."""
        alpha = K.random_uniform((self.batch_size, 1, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])
class WGANGP():
    """3D Wasserstein GAN with gradient penalty, built from residual blocks.

    The constructor builds the generator and the discriminator (critic) and
    compiles two training graphs:

    * ``discriminator_model`` -- inputs ``[real_img, noise]``, outputs the
      critic score of the real volume, of a generated volume and of a random
      interpolation between the two (used for the gradient penalty).
    * ``generator_model`` -- input ``noise``, output the critic score of the
      generated volume.

    Volumes have shape ``(img_dim, img_rows, img_cols, channels)``, i.e.
    ``(64, 128, 128, 1)`` with the constants set in ``__init__``.
    """

    def __init__(self, batch_size, epochs, load_Weights=False, sample_interval=50):
        """Build the networks and compile both training graphs.

        Args:
            batch_size: Samples per training step; also fixes the shape of the
                random interpolation weights in ``RandomWeightedAverage``.
            epochs: Number of generator updates performed by ``train``.
            load_Weights: Stored on the instance; not read anywhere in this
                class.
            sample_interval: Every this many epochs ``train`` saves a preview
                image, the loss log and both models.
        """
        self.sample_interval = sample_interval
        self.load_Weights = load_Weights
        self.epochs = epochs
        self.img_rows = 128
        self.img_cols = 128
        self.img_dim = 64
        self.channels = 1
        self.init_learning_rate = 0.00005
        # (slices, rows, cols, channels); the column count comes from img_cols
        # so non-square slices are described correctly.
        self.img_shape = (self.img_dim, self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 512
        self.batch_size = batch_size
        self.fileRepeatNumber = 4
        self.filePath = '...'
        self.outputsPath = '...'
        self.log_dir = os.path.join(
            self.outputsPath + "logs",
            datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
        )

        # Number of critic updates per generator update.
        self.n_discriminator = 5
        # NOTE(review): RMSprop with lr=5e-5 is the setting of the original
        # (weight-clipping) WGAN; the WGAN-GP paper trains with Adam
        # (lr=1e-4, beta_1=0.5, beta_2=0.9) -- worth trying if samples collapse.
        optimizer = RMSprop(lr=self.init_learning_rate)

        # Build the generator and discriminator
        self.generator = self.build_generator()
        self.discriminator = self.build_Discriminator()
        self.generator.summary()
        self.discriminator.summary()

        # -------------------------------
        # Computational graph for the discriminator
        # -------------------------------

        # Freeze generator's layers while training discriminator
        self.generator.trainable = False

        # Image input (real sample)
        real_img = Input(shape=self.img_shape)
        # Noise input
        z_disc = Input(shape=(self.latent_dim,))
        # Generate image based on noise (fake sample)
        fake_img = self.generator(z_disc)

        # Discriminator scores the real and fake images
        fake = self.discriminator(fake_img)
        valid = self.discriminator(real_img)

        # Random per-sample interpolation between real and fake images
        interpolated_img = RandomWeightedAverage(batch_size=self.batch_size)(
            inputs=[real_img, fake_img])
        # Score of the interpolated sample
        validity_interpolated = self.discriminator(interpolated_img)

        # Keras losses only receive (y_true, y_pred); bind the interpolated
        # tensor so the penalty can differentiate with respect to it.
        partial_gp_loss = partial(self.gradient_penalty_loss,
                                  averaged_samples=interpolated_img)
        partial_gp_loss.__name__ = 'gradient_penalty'  # Keras requires function names

        self.discriminator_model = Model(inputs=[real_img, z_disc],
                                         outputs=[valid, fake, validity_interpolated])
        self.discriminator_model.compile(loss=[self.wasserstein_loss,
                                               self.wasserstein_loss,
                                               partial_gp_loss],
                                         optimizer=optimizer,
                                         loss_weights=[1, 1, 10])

        # -------------------------------
        # Computational graph for the generator
        # -------------------------------

        # For the generator we freeze the discriminator's layers
        self.discriminator.trainable = False
        self.generator.trainable = True

        # Sampled noise for input to generator
        z_gen = Input(shape=(self.latent_dim,))
        # Generate images based on noise
        img = self.generator(z_gen)
        # Discriminator scores the generated image
        valid = self.discriminator(img)
        # Defines generator model
        self.generator_model = Model(z_gen, valid)
        self.generator_model.compile(loss=self.wasserstein_loss, optimizer=optimizer)

    def gradient_penalty_loss(self, y_true, y_pred, averaged_samples):
        """Compute the gradient penalty ``mean((1 - ||grad||_2) ** 2)``.

        Args:
            y_true: Ignored; present only because Keras passes a target.
            y_pred: Critic output for ``averaged_samples``.
            averaged_samples: Interpolated real/fake tensor the gradient is
                taken with respect to.

        Returns:
            Scalar tensor: the penalty averaged over the batch (the weight
            of 10 is applied through ``loss_weights`` at compile time).
        """
        gradients = K.gradients(y_pred, averaged_samples)[0]
        # Euclidean norm per sample: square, sum over every non-batch axis, sqrt.
        gradients_sqr = K.square(gradients)
        gradients_sqr_sum = K.sum(gradients_sqr,
                                  axis=np.arange(1, len(gradients_sqr.shape)))
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        # (1 - ||grad||)^2 for each single sample
        gradient_penalty = K.square(1 - gradient_l2_norm)
        # Mean over all the batch samples
        return K.mean(gradient_penalty)

    def wasserstein_loss(self, y_true, y_pred):
        """Return ``mean(y_true * y_pred)``; ``y_true`` is -1 / +1 and sets the sign."""
        return K.mean(y_true * y_pred)

    def identity_block(self, X, f, filters, stage, block, Name=None):
        """Residual block that keeps the spatial size and adds the raw input back.

        Main path: two rounds of BatchNorm -> LeakyReLU(0.2) -> Conv3D
        (stride 1, 'same' padding). The untouched input is added to the result,
        so ``filters[1]`` must equal the channel count of ``X``.

        Args:
            X: Input tensor (channels on axis 4).
            f: Cubic kernel size of both convolutions.
            filters: Three filter counts; only the first two are used.
            stage: Stage number, used in layer names.
            block: Block label, used in layer names.
            Name: Optional suffix that keeps layer names unique; ``None`` is
                treated as an empty suffix.

        Returns:
            Output tensor of the block.
        """
        suffix = Name or ''
        conv_name_base = 'res' + str(stage) + block + suffix
        bn_name_base = 'bn' + str(stage) + block + suffix

        F1, F2, _ = filters

        # Saved so it can be added back to the main path.
        X_shortcut = X

        # First component of main path
        X = BatchNormalization(axis=4, name=bn_name_base + '2a', momentum=0.8)(X)
        X = LeakyReLU(alpha=0.2)(X)
        X = Conv3D(filters=F1, kernel_size=(f, f, f), strides=(1, 1, 1), padding='same',
                   name=conv_name_base + '2a', kernel_initializer=glorot_uniform(seed=0))(X)

        # Second component of main path
        X = BatchNormalization(axis=4, name=bn_name_base + '2b', momentum=0.8)(X)
        X = LeakyReLU(alpha=0.2)(X)
        X = Conv3D(filters=F2, kernel_size=(f, f, f), strides=(1, 1, 1), padding='same',
                   name=conv_name_base + '2b', kernel_initializer=glorot_uniform(seed=0))(X)

        # Add shortcut value to main path (no activation is applied afterwards).
        X = Add()([X, X_shortcut])
        return X

    def convolutional_block(self, X, f, filters, stage, block, AvgPooling=False, Name=None, mode=None):
        """Residual block that resamples by a factor of 2 in every spatial axis.

        With ``AvgPooling`` false the block upsamples (generator use): both
        paths start with ``UpSampling3D``. With ``AvgPooling`` true the block
        downsamples (critic use): the main path ends with ``AveragePooling3D``
        and the shortcut is pooled before (``mode='opt'``) or after
        (``mode='down'``) its own BatchNorm -> LeakyReLU -> Conv3D.

        Args:
            X: Input tensor (channels on axis 4).
            f: Cubic kernel size of all convolutions.
            filters: Three filter counts; only the first two are used. The
                block outputs ``filters[1]`` channels.
            stage: Stage number, used in layer names.
            block: Block label, used in layer names.
            AvgPooling: Downsample when true, upsample otherwise.
            Name: Optional suffix that keeps layer names unique; ``None`` is
                treated as an empty suffix.
            mode: ``'opt'`` or ``'down'``; required when ``AvgPooling`` is true.

        Returns:
            Output tensor of the block.

        Raises:
            ValueError: If ``AvgPooling`` is true and ``mode`` is neither
                ``'opt'`` nor ``'down'`` (the shortcut would not be pooled and
                the final ``Add`` would fail on mismatched shapes).
        """
        downsample = bool(AvgPooling)
        if downsample and mode not in ('opt', 'down'):
            raise ValueError(
                "mode must be 'opt' or 'down' when AvgPooling is True, got %r" % (mode,))

        suffix = Name or ''
        conv_name_base = 'res' + str(stage) + block + suffix
        bn_name_base = 'bn' + str(stage) + block + suffix

        F1, F2, _ = filters

        # Save the input value
        X_shortcut = X

        # ----- main path -----
        if not downsample:
            X = UpSampling3D(size=(2, 2, 2))(X)
        X = BatchNormalization(axis=4, name=bn_name_base + '2a', momentum=0.8)(X)
        X = LeakyReLU(alpha=0.2)(X)
        X = Conv3D(filters=F1, kernel_size=(f, f, f), strides=(1, 1, 1), padding='same',
                   name=conv_name_base + '2a', kernel_initializer=glorot_uniform(seed=0))(X)

        X = BatchNormalization(axis=4, name=bn_name_base + '2b', momentum=0.8)(X)
        X = LeakyReLU(alpha=0.2)(X)
        X = Conv3D(filters=F2, kernel_size=(f, f, f), strides=(1, 1, 1), padding='same',
                   name=conv_name_base + '2b', kernel_initializer=glorot_uniform(seed=0))(X)
        if downsample:
            X = AveragePooling3D(pool_size=(2, 2, 2), name='AvgPooling' + str(stage) + block)(X)

        # ----- shortcut path -----
        if not downsample:
            X_shortcut = UpSampling3D(size=(2, 2, 2))(X_shortcut)
        if downsample and mode == 'opt':
            X_shortcut = AveragePooling3D(pool_size=(2, 2, 2),
                                          name='AvgPooling' + str(stage) + block + '_shortcut')(X_shortcut)
        X_shortcut = BatchNormalization(axis=4, name=bn_name_base + '1', momentum=0.8)(X_shortcut)
        X_shortcut = LeakyReLU(alpha=0.2)(X_shortcut)
        X_shortcut = Conv3D(filters=F2, kernel_size=(f, f, f), strides=(1, 1, 1), padding='same',
                            name=conv_name_base + '1', kernel_initializer=glorot_uniform(seed=0))(X_shortcut)
        if downsample and mode == 'down':
            X_shortcut = AveragePooling3D(pool_size=(2, 2, 2),
                                          name='AvgPooling' + str(stage) + block + '_shortcut')(X_shortcut)

        # Add shortcut value to main path (no activation is applied afterwards).
        X = Add()([X, X_shortcut])
        return X

    def build_generator(self):
        """Build the generator: latent vector -> tanh volume.

        Three dense layers produce a ``(2, 4, 4, 128)`` seed that five
        upsampling residual blocks enlarge 32x per axis, followed by a
        1x1x1 convolution to one channel and ``tanh``.

        Returns:
            Keras ``Model`` named ``'Generator'``.
        """
        X_input = Input(shape=(self.latent_dim,))

        # Fully connected stem
        X = Dense(64, name='FC1_Generator')(X_input)
        X = Activation('relu')(X)
        X = Dense(512, name='FC2_Generator')(X)
        X = Activation('relu')(X)
        X = Dense(2*4*4*128, name='FC3_Generator')(X)
        X = Reshape((2, 4, 4, 128))(X)

        # Five upsampling stages
        X = self.convolutional_block(X, f=3, filters=[256, 256, 512], stage=1, block='a', Name='_Generator')
        X = self.convolutional_block(X, f=3, filters=[128, 128, 256], stage=2, block='a', Name='_Generator')
        X = self.convolutional_block(X, f=3, filters=[64, 64, 128], stage=3, block='a', Name='_Generator')
        X = self.convolutional_block(X, f=3, filters=[32, 32, 64], stage=4, block='a', Name='_Generator')
        X = self.convolutional_block(X, f=3, filters=[16, 16, 32], stage=5, block='a', Name='_Generator')

        X = LeakyReLU(alpha=0.2)(X)
        X = Conv3D(filters=1, kernel_size=(1, 1, 1), strides=(1, 1, 1), padding='same',
                   name='last_Conv_Generator',
                   kernel_initializer=glorot_uniform(seed=0))(X)
        X = Activation('tanh')(X)

        return Model(inputs=X_input, outputs=X, name='Generator')

    def build_Discriminator(self):
        """Build the critic: volume -> single unbounded score.

        Four downsampling residual blocks (stages 0-3) are chained, followed
        by two identity blocks, a 1x1x1 convolution and a dense head ending
        in one linear unit.

        Returns:
            Keras ``Model`` named ``'Discriminator'``.
        """
        # NOTE(review): the residual blocks use BatchNormalization; the
        # WGAN-GP paper advises against batch norm in the critic (layer norm
        # is suggested instead) -- consider revisiting if training is unstable.
        X_input = Input(shape=(self.img_shape))

        X = self.convolutional_block(X_input, f=3, filters=[16, 16, 32], stage=0, block='a',
                                     AvgPooling=True, Name='_Discriminator', mode='opt')
        # Stage 1 consumes the stage-0 output so that stage 0 is part of the model.
        X = self.convolutional_block(X, f=3, filters=[32, 32, 64], stage=1, block='a',
                                     AvgPooling=True, Name='_Discriminator', mode='down')
        X = self.convolutional_block(X, f=3, filters=[64, 64, 128], stage=2, block='a',
                                     AvgPooling=True, Name='_Discriminator', mode='down')
        X = self.convolutional_block(X, f=3, filters=[128, 128, 256], stage=3, block='a',
                                     AvgPooling=True, Name='_Discriminator', mode='down')
        X = self.identity_block(X, 3, [128, 128, 256], stage=3, block='b', Name='_Discriminator')
        X = self.identity_block(X, 3, [128, 128, 256], stage=3, block='c', Name='_Discriminator')

        X = LeakyReLU(alpha=0.2)(X)
        X = Conv3D(filters=1, kernel_size=(1, 1, 1), strides=(1, 1, 1), padding='same',
                   name='last_Conv_Discriminator',
                   kernel_initializer=glorot_uniform(seed=0))(X)
        X = Flatten()(X)

        X = Dense(256, name='FC1_Discriminator')(X)
        X = Activation('relu')(X)
        X = Dense(32, name='FC2_Discriminator')(X)
        X = Activation('relu')(X)
        X = Dense(16, name='FCXX_Discriminator')(X)
        X = Activation('relu')(X)
        X = Dense(1, name='FC3_Discriminator')(X)

        return Model(inputs=X_input, outputs=X, name='Discriminator')

    def train(self):
        """Run the adversarial training loop.

        Every file in ``self.filePath`` is treated as one ``np.load``-able
        training volume. Each epoch performs ``n_discriminator`` critic
        updates on freshly sampled batches followed by one generator update,
        then records the losses. Every ``sample_interval`` epochs a preview
        image, the loss log and both models are written to ``outputsPath``.

        Raises:
            ValueError: If ``self.filePath`` contains no files.
        """
        # NOTE(review): the source's indentation was lost; the nesting below
        # (generator step once per epoch, checkpoints only at the sample
        # interval) follows the usual WGAN-GP schedule -- confirm against the
        # original script.
        self.x_train = np.array([
            os.path.join(folder, file_name)
            for folder in [self.filePath]
            for file_name in os.listdir(folder)
        ])
        if self.x_train.shape[0] == 0:
            raise ValueError("no training files found in %r" % (self.filePath,))

        # Adversarial ground truths
        self.valid = -np.ones((self.batch_size, 1))
        self.fake = np.ones((self.batch_size, 1))
        self.dummy = np.zeros((self.batch_size, 1))  # Dummy gt for gradient penalty

        self.h_g_loss = []
        self.h_d_loss = []

        self.writer = tf.summary.FileWriter(self.log_dir)
        try:
            for epoch in range(self.epochs):
                for _ in range(self.n_discriminator):
                    # ---------------------
                    #  Train Discriminator
                    # ---------------------
                    idx = np.random.randint(0, self.x_train.shape[0], self.batch_size)
                    imgs = np.array([np.load(path) for path in self.x_train[idx]]).astype(np.float64)

                    # Rescale to [-1, 1]; assumes intensities start at 0 -- TODO confirm.
                    half_max = imgs.max() / 2
                    if half_max == 0:
                        # All-zero batch: avoid 0/0 and map it to the lower bound.
                        half_max = 1.0
                    imgs = (imgs - half_max) / half_max
                    imgs = np.expand_dims(imgs, axis=4)

                    # Sample generator input
                    noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))

                    # d_loss is [total, real, fake, gradient_penalty]
                    d_loss = self.discriminator_model.train_on_batch(
                        [imgs, noise], [self.valid, self.fake, self.dummy])

                # ---------------------
                #  Train Generator
                # ---------------------
                g_loss = self.generator_model.train_on_batch(noise, self.valid)

                self.h_g_loss.append(g_loss)
                self.h_d_loss.append(d_loss)

                # Index 0 of d_loss is the weighted total, so the per-output
                # real/fake terms live at indices 1 and 2.
                summary = tf.Summary(value=[
                    tf.Summary.Value(tag="G_loss", simple_value=g_loss),
                    tf.Summary.Value(tag="D_loss_real", simple_value=d_loss[1]),
                    tf.Summary.Value(tag="D_loss_fake", simple_value=d_loss[2]),
                ])
                self.writer.add_summary(summary, epoch)

                # Plot the progress
                print ("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))

                # If at save interval => save generated image samples and models
                if epoch % self.sample_interval == 0:
                    self.sample_images(epoch)
                    self.saveLog()
                    N_ge = self.outputsPath+'generator_model_%d.h5' % epoch
                    N_di = self.outputsPath+'discriminator_model_%d.h5' % epoch
                    self.generator.save(N_ge)
                    self.discriminator.save(N_di)
        finally:
            # Flush and release the event file even if training is interrupted.
            self.writer.close()

    def sample_images(self, epoch):
        """Generate one volume and save its slices as an 8x8 grayscale grid.

        Args:
            epoch: Epoch number, used in the output file name.
        """
        r, c = 8, 8
        noise = np.random.normal(0, 1, (1, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        volume = gen_imgs[0, :, :, :, 0]

        fig, axs = plt.subplots(r, c, figsize=((c*self.img_cols)/100, (r*self.img_rows)/100))
        # axs.flat walks the grid row by row, one slice per cell.
        for slice_idx, ax in enumerate(axs.flat):
            ax.axis('off')
            if slice_idx < volume.shape[0]:
                # imshow stretches each slice between its own min and max;
                # a column-wise scaler here would distort the picture.
                ax.imshow(volume[slice_idx], cmap='gray')
        fig.savefig(self.outputsPath+"brain_%d.png" % epoch)
        plt.close(fig)

    def saveLog(self):
        """Plot the loss history and save it, plus the raw arrays, to ``outputsPath``."""
        d_history = np.array(self.h_d_loss)
        # Columns are [total, real, fake, gradient_penalty].
        plt.plot(d_history[:, 1], label='crit_real')
        plt.plot(d_history[:, 2], label='crit_fake')
        plt.plot(self.h_g_loss, label='gen')
        plt.legend()
        plt.grid()
        plt.savefig(self.outputsPath+'plot_loss.png')
        plt.close()
        np.save(self.outputsPath+'h_g_loss', np.array(self.h_g_loss))
        np.save(self.outputsPath+'h_d_loss', d_history)
if __name__ == '__main__':
    # Script entry point: build both training graphs, then run the full
    # adversarial training loop with one sample per batch.
    trainer = WGANGP(
        batch_size=1,
        epochs=30000,
        load_Weights=True,
        sample_interval=50,
    )
    trainer.train()
Вот мой процесс обучения и полученные результаты:
Сгенерированный результат 1
Сгенерированный результат 2