Обучение 3D КТ изображений с использованием керасов и WGAN-GP - PullRequest
0 голосов
/ 02 апреля 2020

Я использую керасы для изучения КТ-изображений, пытаясь тренировать 3D с помощью WGAN-GP. Я использовал этот пример для переписывания и построения остаточной сетевой архитектуры для успешного обучения. Но мои сгенерированные изображения, как правило, одинаковы (как в результате1 и 2). Поскольку я не могу изменить программу оригинального автора в трехмерную форму, я попытался использовать keras для построения каркаса и гиперпараметров, близких к архитектуре оригинального автора . Я использую 41 обучающие данные с формой (срез, строки, столбцы, канал) = (64, 128, 128, 1). У кого-нибудь есть предложения по улучшению? Большое спасибо!

class RandomWeightedAverage(_Merge):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
    """Provides a (random) weighted average between real and generated image samples"""
    def _merge_function(self, inputs):
        alpha = K.random_uniform((self.batch_size, 1, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

class WGANGP():
    def __init__(self, batch_size, epochs, load_Weights=False, sample_interval=50):
        self.sample_interval = sample_interval
        self.load_Weights = load_Weights
        self.epochs = epochs
        self.img_rows = 128
        self.img_cols = 128
        self.img_dim = 64
        self.channels = 1
        self.init_learning_rate = 0.00005
        self.img_shape = (self.img_dim, self.img_rows, self.img_rows, self.channels)
        self.latent_dim = 512
        self.batch_size = batch_size
        self.fileRepeatNumber = 4

        self.filePath='...'
        self.outputsPath = '...'
        self.log_dir = os.path.join(self.outputsPath+"logs",datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),)

        # Following parameter and optimizer set as recommended in paper
        self.n_discriminator = 5
        optimizer = RMSprop(lr=self.init_learning_rate)

        # Build the generator and discriminator
        self.generator = self.build_generator()
        self.discriminator = self.build_Discriminator()
        self.generator.summary()
        self.discriminator.summary()

        #-------------------------------
        # Construct Computational Graph
        #       for the discriminator
        #-------------------------------
        # Freeze generator's layers while training discriminator
        self.generator.trainable = False

        # Image input (real sample)
        real_img = Input(shape=self.img_shape)

        # Noise input
        z_disc = Input(shape=(self.latent_dim,))
        # Generate image based of noise (fake sample)
        fake_img = self.generator(z_disc)

        # Discriminator determines validity of the real and fake images
        fake = self.discriminator(fake_img)
        valid = self.discriminator(real_img)

        # Construct weighted average between real and fake images
        interpolated_img = RandomWeightedAverage(batch_size=self.batch_size)(inputs=[real_img, fake_img])
        # Determine validity of weighted sample
        validity_interpolated = self.discriminator(interpolated_img)

        # Use Python partial to provide loss function with additional
        # 'averaged_samples' argument
        partial_gp_loss = partial(self.gradient_penalty_loss,
                          averaged_samples=interpolated_img)
        partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names

        self.discriminator_model = Model(inputs=[real_img, z_disc],
                            outputs=[valid, fake, validity_interpolated])
        self.discriminator_model.compile(loss=[self.wasserstein_loss,
                                        self.wasserstein_loss,
                                        partial_gp_loss],
                                        optimizer=optimizer,
                                        loss_weights=[1, 1, 10])
        #-------------------------------
        # Construct Computational Graph
        #         for Generator
        #-------------------------------
        # For the generator we freeze the discriminator's layers
        self.discriminator.trainable = False
        self.generator.trainable = True

        # Sampled noise for input to generator
        z_gen = Input(shape=(self.latent_dim,))
        # Generate images based of noise
        img = self.generator(z_gen)
        # Discriminator determines validity
        valid = self.discriminator(img)
        # Defines generator model
        self.generator_model = Model(z_gen, valid)
        self.generator_model.compile(loss=self.wasserstein_loss, optimizer=optimizer)

    def gradient_penalty_loss(self, y_true, y_pred, averaged_samples):
        """
        Computes gradient penalty based on prediction and weighted real / fake samples
        """
        # gradients = K.gradients(self.discriminator_model.get_layer('discriminator').outputs[-1], averaged_samples)[0]
        gradients = K.gradients(y_pred, averaged_samples)[0]
        # compute the euclidean norm by squaring ...
        gradients_sqr = K.square(gradients)
        #   ... summing over the rows ...
        gradients_sqr_sum = K.sum(gradients_sqr,
                                  axis=np.arange(1, len(gradients_sqr.shape)))
        #   ... and sqrt
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        # compute lambda * (1 - ||grad||)^2 still for each single sample
        gradient_penalty = K.square(1 - gradient_l2_norm)
        # return the mean as loss over all the batch samples
        return K.mean(gradient_penalty)


    def wasserstein_loss(self, y_true, y_pred):
        return K.mean(y_true * y_pred)

    def identity_block(self, X, f, filters, stage, block, Name=None):
        # defining name basis
        conv_name_base = 'res' + str(stage) + block + Name
        bn_name_base = 'bn' + str(stage) + block + Name

        # Retrieve Filters
        F1, F2, F3 = filters

        # Save the input value. You'll need this later to add back to the main path. 
        X_shortcut = X

        # First component of main path
        X = BatchNormalization(axis = 4, name = bn_name_base + '2a', momentum=0.8)(X)
        X = LeakyReLU(alpha=0.2)(X)
        X = Conv3D(filters = F1, kernel_size = (f,f,f), strides = (1,1,1), padding = 'same', 
            name = conv_name_base + '2a', kernel_initializer = glorot_uniform(seed=0))(X)

        # Second component of main path (≈3 lines)
        X = BatchNormalization(axis = 4, name = bn_name_base + '2b', momentum=0.8)(X)
        X = LeakyReLU(alpha=0.2)(X)
        X = Conv3D(filters = F2, kernel_size = (f,f,f), strides = (1,1,1), padding = 'same', 
            name = conv_name_base + '2b', kernel_initializer = glorot_uniform(seed=0))(X)

        # Final step: Add shortcut value to main path, and pass it through a RELU activation (≈2 lines)
        X = Add()([X, X_shortcut])

        return X

    def convolutional_block(self, X, f, filters, stage, block, AvgPooling=False, Name=None, mode=None):
        # defining name basis
        conv_name_base = 'res' + str(stage) + block + Name
        bn_name_base = 'bn' + str(stage) + block + Name

        # Retrieve Filters
        F1, F2, F3 = filters

        # Save the input value
        X_shortcut = X

        ##### MAIN PATH #####
        # First component of main path
        if AvgPooling is False:
            X = UpSampling3D(size=(2,2,2))(X)

        X = BatchNormalization(axis = 4, name = bn_name_base + '2a', momentum=0.8)(X)
        X = LeakyReLU(alpha=0.2)(X)
        X = Conv3D(filters = F1, kernel_size=(f,f,f), strides = (1,1,1), padding = 'same',
            name = conv_name_base + '2a', kernel_initializer = glorot_uniform(seed=0))(X)

        # Second component of main path (≈3 lines)
        X = BatchNormalization(axis = 4, name = bn_name_base + '2b', momentum=0.8)(X)
        X = LeakyReLU(alpha=0.2)(X)
        X = Conv3D(filters = F2, kernel_size = (f,f,f), strides = (1,1,1), padding = 'same', 
            name = conv_name_base + '2b', kernel_initializer = glorot_uniform(seed=0))(X)

        if AvgPooling:
            X = AveragePooling3D(pool_size=(2,2,2), name='AvgPooling'+str(stage)+block)(X)

        ##### SHORTCUT PATH #### (≈2 lines)
        if AvgPooling is False:
            X_shortcut = UpSampling3D(size=(2,2,2))(X_shortcut)

        if AvgPooling and mode=='opt':
            X_shortcut = AveragePooling3D(pool_size=(2,2,2), name='AvgPooling'+str(stage)+block+'_shortcut')(X_shortcut)
        X_shortcut = BatchNormalization(axis = 4, name = bn_name_base + '1', momentum=0.8)(X_shortcut)
        X_shortcut = LeakyReLU(alpha=0.2)(X_shortcut)
        X_shortcut = Conv3D(filters = F2, kernel_size = (f,f,f), strides = (1,1,1), padding = 'same', 
            name = conv_name_base + '1', kernel_initializer = glorot_uniform(seed=0))(X_shortcut)
        if AvgPooling and mode=='down':
            X_shortcut = AveragePooling3D(pool_size=(2,2,2), name='AvgPooling'+str(stage)+block+'_shortcut')(X_shortcut)

        # Final step: Add shortcut value to main path, and pass it through a RELU activation (≈2 lines)
        X = Add()([X, X_shortcut])

        return X

    def build_generator(self):
        # Define the input as a tensor with shape input_shape
        X_input = Input(shape=(self.latent_dim,))
        # FFC-GAN FC layers
        X = Dense(64, name='FC1_Generator')(X_input)
        X = Activation('relu')(X)
        X = Dense(512, name='FC2_Generator')(X)
        X = Activation('relu')(X)
        X = Dense(2*4*4*128, name='FC3_Generator')(X)
        X = Reshape((2,4,4,128))(X)
        X = self.convolutional_block(X, f=3, filters=[256, 256, 512], stage=1, block='a', Name='_Generator')
        X = self.convolutional_block(X, f = 3, filters = [128, 128, 256], stage=2, block='a', Name='_Generator')
        X = self.convolutional_block(X, f = 3, filters = [64, 64, 128], stage=3, block='a', Name='_Generator')
        X = self.convolutional_block(X, f = 3, filters = [32, 32, 64], stage=4, block='a', Name='_Generator')
        X = self.convolutional_block(X, f = 3, filters = [16, 16, 32], stage=5, block='a', Name='_Generator')  
        X = LeakyReLU(alpha=0.2)(X)
        X = Conv3D(filters=1 , kernel_size = (1,1,1), strides = (1,1,1), padding = 'same', 
                   name = 'last_Conv_Generator', 
                   kernel_initializer = glorot_uniform(seed=0))(X)
        X = Activation('tanh')(X)

        return Model(inputs = X_input, outputs = X, name='Generator')

    def build_Discriminator(self):
        # Define the input as a tensor with shape input_shape
        X_input = Input(shape=(self.img_shape))
        X = self.convolutional_block(X_input, f=3, filters=[16, 16, 32], stage=0, block='a', 
            AvgPooling=True, Name='_Discriminator', mode='opt')
        X = self.convolutional_block(X_input, f=3, filters=[32, 32, 64], stage=1, block='a', 
            AvgPooling=True, Name='_Discriminator', mode='down')
        X = self.convolutional_block(X, f=3, filters=[64, 64, 128], stage=2, block='a', 
            AvgPooling=True, Name='_Discriminator', mode='down')
        X = self.convolutional_block(X, f=3, filters=[128, 128, 256], stage=3, block='a',
            AvgPooling=True, Name='_Discriminator', mode='down')
        X = self.identity_block(X, 3, [128, 128, 256], stage=3, block='b', Name='_Discriminator')
        X = self.identity_block(X, 3, [128, 128, 256], stage=3, block='c', Name='_Discriminator')

        X = LeakyReLU(alpha=0.2)(X)
        X = Conv3D(filters = 1, kernel_size = (1,1,1), strides = (1,1,1), padding = 'same', 
                   name = 'last_Conv_Discriminator', 
                   kernel_initializer = glorot_uniform(seed=0))(X)
        X = Flatten()(X)
        X = Dense(256, name='FC1_Discriminator')(X)
        X = Activation('relu')(X)
        X = Dense(32, name='FC2_Discriminator')(X)
        X = Activation('relu')(X)
        X = Dense(16, name='FCXX_Discriminator')(X)
        X = Activation('relu')(X)
        X = Dense(1, name='FC3_Discriminator')(X)
        return Model(inputs = X_input, outputs = X, name='Discriminator')

    def train(self):
        def runDiscriminator(epoch):


        self.writer = tf.summary.FileWriter(self.log_dir)

        self.x_train = []
        for index, i in enumerate([self.filePath]):
            for j in os.listdir(i):
                self.x_train.append(i+j)
        self.x_train = np.array(self.x_train)

        # Adversarial ground truths
        self.valid = -np.ones((self.batch_size, 1))
        self.fake =  np.ones((self.batch_size, 1))
        self.dummy = np.zeros((self.batch_size, 1)) # Dummy gt for gradient penalty

        self.h_g_loss = []
        self.h_d_loss = []
        for epoch in range(self.epochs):
            for _ in range(self.n_discriminator):

                # ---------------------
                #  Train Discriminator
                # ---------------------
                imgs = []
                idx = np.random.randint(0, self.x_train.shape[0], self.batch_size)
                for x in self.x_train[idx]:
                    imgs.append(np.load(x))
                # Rescale -1 to 1
                imgs = np.array(imgs).astype(np.float64)
                imgs = (imgs - (imgs.max()/2)) / (imgs.max()/2)
                imgs = np.expand_dims(imgs, axis=4)

                # Sample generator input
                noise = np.random.normal(0, 1, (self.batch_size, self.latent_dim))
                # Train the discriminator
                d_loss = self.discriminator_model.train_on_batch([imgs, noise], [self.valid, self.fake, self.dummy])

            # ---------------------
            #  Train Generator
            # ---------------------
            g_loss = self.generator_model.train_on_batch(noise, self.valid)
            self.h_g_loss.append(g_loss)
            self.h_d_loss.append(d_loss)

            summary = tf.Summary(value=[
                    tf.Summary.Value(tag="G_loss", simple_value=g_loss),
                    tf.Summary.Value(tag="D_loss_real", simple_value=d_loss[0]),
                    tf.Summary.Value(tag="D_loss_fake", simple_value=d_loss[1]),])
            self.writer.add_summary(summary, epoch)

            # Plot the progress
            print ("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))

            # If at save interval => save generated image samples
            if epoch % self.sample_interval == 0:
                self.sample_images(epoch)
                self.saveLog()
                N_ge = self.outputsPath+'generator_model_%d.h5' % epoch
                N_di = self.outputsPath+'discriminator_model_%d.h5' % epoch
                self.generator.save(N_ge)
                self.discriminator.save(N_di)
        self.writer.close()

    def sample_images(self, epoch):
        scaler = MinMaxScaler()
        r, c = 8, 8
        noise = np.random.normal(0, 1, (1, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        fig, axs = plt.subplots(r, c, figsize = ((c*self.img_rows)/100, (r*self.img_rows)/100))
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(scaler.fit_transform(gen_imgs[0,cnt,:,:,0]), cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig(self.outputsPath+"brain_%d.png" % epoch)
        plt.close()

    def saveLog(self):
        plt.plot(np.array(self.h_d_loss)[:,0], label='crit_real')
        plt.plot(np.array(self.h_d_loss)[:,1], label='crit_fake')
        plt.plot(self.h_g_loss, label='gen')
        plt.legend()
        plt.grid()
        plt.savefig(self.outputsPath+'plot_loss.png')
        plt.close()

        np.save(self.outputsPath+'h_g_loss', np.array(self.h_g_loss))
        np.save(self.outputsPath+'h_d_loss', np.array(self.h_d_loss))

if __name__ == '__main__':
    wgan = WGANGP(batch_size=1, epochs=30000, load_Weights=True, sample_interval=50)
    wgan.train()

Вот мой тренировочный процесс и результаты enter image description here enter image description here enter image description here

Генерировать результат 1

Генерировать результат 2 enter image description here

...