Как я могу применить медианный фильтр для этого генератора? - PullRequest
1 голос
/ 26 января 2020

Я хочу применить медианный фильтр для этого кода, чтобы уменьшить шум на выходе генератора. Я не знаю, куда мне добавить медианный фильтр, после tanh активации или перед ним или в другом месте? Какой подходящий код медианного фильтра соответствует следующему коду?

def generator(n_samples, noise=None, use_bn=False,  net_dim=64, 
        output_dim=64,is_training=False, latent_dim=128, stats_iter=None):
    if noise is None:
        noise = tf.random_normal([n_samples, latent_dim])

    output = lib.ops.linear.Linear('Generator.Input', latent_dim,
                              4 * 4 * 4 * net_dim, noise)
    output = tf.nn.relu(output)
    output = tf.reshape(output, [-1, 4, 4, 4 * net_dim])
    output = lib.ops.deconv2d.Deconv2D('Generator.2', 4 * net_dim, 2 * net_dim, 5, output)
    output = tf.nn.relu(output)
    output = lib.ops.deconv2d.Deconv2D('Generator.3', 2 * net_dim,  net_dim, 5, output)
    output = tf.nn.relu(output)
    output = lib.ops.deconv2d.Deconv2D('Generator.5',  net_dim, net_dim, 5, output)
    output = lib.ops.deconv2d.Deconv2D('Generator.6', net_dim, 3, 5, output)
    output = tf.tanh(output)
    return output


class DefGAN(DefGANBase):
    def _build_generator(self):

        self.generator_fn = lambda z=None, is_training=self.is_training: \
            generator(self.batch_size,
                      use_bn=self.use_bn,
                      net_dim=self.net_dim,
                      is_training=is_training,
                      latent_dim=self.latent_dim,
                      output_dim=self.image_dim,
                      noise=z,
                      stats_iter=self.global_step)

Я применил этот медианный фильтр:

 self.fixed_noise_samples = self.generator_fn(self.fixed_noise,  
 is_training=False)   
 filt_length=3
 edges = filt_length// 2
 # convert to 4D, where data is in 3rd dim (e.g. data[0,0,:,0]
 exp_data = tf.expand_dims(tf.expand_dims(self.fixed_noise_samples, 0), -1)
 # get rolling window
 wins= tf.image.extract_patches(images=exp_data, sizes=[1, filt_length, 1, 1],
             strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='VALID')
 # get median of each window
 wins = tf.math.top_k(wins, k=2)[0][0, :, :, edges]
 # Concat edges
 self.fixed_noise_samples=tf.concat((self.fixed_noise_samples[:edges, :], 
 wins, data[-edges:, :]), 0)

, но я получаю эту ошибку:

 ValueError: Shape must be rank 4 but is rank 6 for 'ExtractImagePatches' (op: 
 'ExtractImagePatches') with input shapes: [1,128,64,64,3,1].
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...