Я хочу применить медианный фильтр для этого кода, чтобы уменьшить шум на выходе генератора. Я не знаю, куда мне добавить медианный фильтр, после tanh
активации или перед ним или в другом месте? Какой подходящий код медианного фильтра соответствует следующему коду?
def generator(n_samples, noise=None, use_bn=False, net_dim=64,
output_dim=64,is_training=False, latent_dim=128, stats_iter=None):
if noise is None:
noise = tf.random_normal([n_samples, latent_dim])
output = lib.ops.linear.Linear('Generator.Input', latent_dim,
4 * 4 * 4 * net_dim, noise)
output = tf.nn.relu(output)
output = tf.reshape(output, [-1, 4, 4, 4 * net_dim])
output = lib.ops.deconv2d.Deconv2D('Generator.2', 4 * net_dim, 2 * net_dim, 5, output)
output = tf.nn.relu(output)
output = lib.ops.deconv2d.Deconv2D('Generator.3', 2 * net_dim, net_dim, 5, output)
output = tf.nn.relu(output)
output = lib.ops.deconv2d.Deconv2D('Generator.5', net_dim, net_dim, 5, output)
output = lib.ops.deconv2d.Deconv2D('Generator.6', net_dim, 3, 5, output)
output = tf.tanh(output)
return output
class DefGAN(DefGANBase):
def _build_generator(self):
self.generator_fn = lambda z=None, is_training=self.is_training: \
generator(self.batch_size,
use_bn=self.use_bn,
net_dim=self.net_dim,
is_training=is_training,
latent_dim=self.latent_dim,
output_dim=self.image_dim,
noise=z,
stats_iter=self.global_step)
Я применил этот медианный фильтр:
self.fixed_noise_samples = self.generator_fn(self.fixed_noise,
is_training=False)
filt_length=3
edges = filt_length// 2
# convert to 4D, where data is in 3rd dim (e.g. data[0,0,:,0]
exp_data = tf.expand_dims(tf.expand_dims(self.fixed_noise_samples, 0), -1)
# get rolling window
wins= tf.image.extract_patches(images=exp_data, sizes=[1, filt_length, 1, 1],
strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='VALID')
# get median of each window
wins = tf.math.top_k(wins, k=2)[0][0, :, :, edges]
# Concat edges
self.fixed_noise_samples=tf.concat((self.fixed_noise_samples[:edges, :],
wins, data[-edges:, :]), 0)
, но я получаю эту ошибку:
ValueError: Shape must be rank 4 but is rank 6 for 'ExtractImagePatches' (op:
'ExtractImagePatches') with input shapes: [1,128,64,64,3,1].