So I'm working on a WGAN. Things have been going terribly from the start, and since I changed the filter counts I no longer have problems with the runtime disconnecting.
Current model:
'WGAN Model'
import tensorflow as tf  # TF 2.0

# For attempt 8: halving all the filter counts
class Generator(tf.keras.Model):
    def __init__(self, channels=3, method='transpose'):
        super(Generator, self).__init__()
        self.channels = channels
        self.method = method
        self.dense = tf.keras.layers.Dense(256 * 32 * 32, use_bias=False)
        self.reshape = tf.keras.layers.Reshape((32, 32, 256))
        if self.method == 'transpose':
            self.convT_1 = tf.keras.layers.Conv2DTranspose(128, (5, 5), padding='same', use_bias=False)
            self.convT_2 = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)
            self.convT_3 = tf.keras.layers.Conv2DTranspose(self.channels, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
        elif self.method == 'upsample':
            self.upsample2d = tf.keras.layers.UpSampling2D()
            self.conv_1 = tf.keras.layers.Conv2D(128, (3, 3), padding='same', use_bias=False)
            self.conv_2 = tf.keras.layers.Conv2D(64, (3, 3), padding='same', use_bias=False)
            self.conv_3 = tf.keras.layers.Conv2D(self.channels, (3, 3), padding='same', use_bias=False, activation='tanh')
        self.batch_norm_1 = tf.keras.layers.BatchNormalization()
        self.batch_norm_2 = tf.keras.layers.BatchNormalization()
        self.batch_norm_3 = tf.keras.layers.BatchNormalization()
        self.leakyrelu_1 = tf.keras.layers.LeakyReLU()
        self.leakyrelu_2 = tf.keras.layers.LeakyReLU()
        self.leakyrelu_3 = tf.keras.layers.LeakyReLU()

    def call(self, inputs, training=True):
        if self.method == 'transpose':
            x = self.dense(inputs)
            x = self.batch_norm_1(x, training)
            x = self.leakyrelu_1(x)
            x = self.reshape(x)
            x = self.convT_1(x)
            x = self.batch_norm_2(x, training)
            x = self.leakyrelu_2(x)
            x = self.convT_2(x)
            x = self.batch_norm_3(x, training)
            x = self.leakyrelu_3(x)
            return self.convT_3(x)
        elif self.method == 'upsample':
            # Replace Conv2DTranspose with UpSampling2D & Conv2D
            x = self.dense(inputs)
            x = self.batch_norm_1(x, training)
            x = self.leakyrelu_1(x)
            x = self.reshape(x)
            x = self.conv_1(x)
            x = self.batch_norm_2(x, training)
            x = self.leakyrelu_2(x)
            x = self.upsample2d(x)
            x = self.conv_2(x)
            x = self.batch_norm_3(x, training)
            x = self.leakyrelu_3(x)
            x = self.upsample2d(x)
            return self.conv_3(x)
class Critic(tf.keras.Model):
    def __init__(self):
        super(Critic, self).__init__()
        self.conv_1 = tf.keras.layers.Conv2D(16, (5, 5), strides=2, padding='same', input_shape=(32, 32, 3))  # 64 to 32; it looks like this layer wants an input of ndim=4
        self.conv_2 = tf.keras.layers.Conv2D(32, (5, 5), strides=2, padding='same')  # 128 to 64
        self.flatten = tf.keras.layers.Flatten()
        self.out = tf.keras.layers.Dense(1)  # What is this 1? Should it be 10 because there are 10 classes?
        self.leakyrelu_1 = tf.keras.layers.LeakyReLU()
        self.leakyrelu_2 = tf.keras.layers.LeakyReLU()
        self.dropout_1 = tf.keras.layers.Dropout(0.3)
        self.dropout_2 = tf.keras.layers.Dropout(0.3)
        self.batch_norm_1 = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=True):
        x = self.conv_1(inputs)
        x = self.leakyrelu_1(x)
        x = self.dropout_1(x, training)
        x = self.conv_2(x)
        x = self.batch_norm_1(x, training)
        x = self.leakyrelu_2(x)
        x = self.dropout_2(x, training)
        x = self.flatten(x)
        return self.out(x)
if __name__ == "__main__":
    pass
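For reference, here is a minimal smoke test (not part of the original code; the latent dimension of 100 is an assumption) that pushes a batch of noise through both models and prints the resulting shapes. With the layer sizes above, the 'transpose' generator produces 128x128x3 images, since the 32x32 feature map is upsampled twice by the strided transposed convolutions.

# Minimal smoke test -- not in the original post; latent dimension 100 is assumed.
noise = tf.random.normal([4, 100])              # batch of 4 latent vectors
generator = Generator(channels=3, method='transpose')
critic = Critic()

fake_images = generator(noise, training=False)  # (4, 128, 128, 3): 32 -> 64 -> 128 via the two strided Conv2DTranspose layers
scores = critic(fake_images, training=False)    # (4, 1): one unbounded Wasserstein score per image

print(fake_images.shape)
print(scores.shape)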
The original code had just the generator and the critic, with the same filter values for both the generator and the critic. I have no idea what is wrong.
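For context, a generic WGAN update around two models like these (this is only a sketch, not the training code from this post; the RMSprop optimizer, the 0.01 clip value, and the latent dimension of 100 are all assumptions) looks roughly like this:

# Generic WGAN training step with weight clipping -- a sketch for context only,
# not the training code from this post. RMSprop(5e-5), a clip value of 0.01 and
# a latent dimension of 100 are assumptions.
generator = Generator()
critic = Critic()
gen_opt = tf.keras.optimizers.RMSprop(5e-5)
critic_opt = tf.keras.optimizers.RMSprop(5e-5)
CLIP_VALUE = 0.01

def critic_step(real_images):
    noise = tf.random.normal([tf.shape(real_images)[0], 100])
    with tf.GradientTape() as tape:
        fake_images = generator(noise, training=True)
        # Wasserstein critic loss: score fakes low and reals high.
        loss = tf.reduce_mean(critic(fake_images, training=True)) \
             - tf.reduce_mean(critic(real_images, training=True))
    grads = tape.gradient(loss, critic.trainable_variables)
    critic_opt.apply_gradients(zip(grads, critic.trainable_variables))
    # Enforce the Lipschitz constraint by clipping the critic's weights.
    for w in critic.trainable_variables:
        w.assign(tf.clip_by_value(w, -CLIP_VALUE, CLIP_VALUE))
    return loss

def generator_step(batch_size):
    noise = tf.random.normal([batch_size, 100])
    with tf.GradientTape() as tape:
        # The generator tries to make the critic score its samples highly.
        loss = -tf.reduce_mean(critic(generator(noise, training=True), training=True))
    grads = tape.gradient(loss, generator.trainable_variables)
    gen_opt.apply_gradients(zip(grads, generator.trainable_variables))
    return loss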