Изменения в слоях теперь приводят к отключению во время выполнения в переполнении стека - PullRequest
0 голосов
/ 09 марта 2020

Так что я работаю с WGAN. У меня было ужасно с самого начала, и когда я изменил количество фильтров, у меня больше нет проблем с отключением во время выполнения

Текущая модель

'WGAN Model'

import tensorflow as tf  # TF 2.0

#For attempt 8 halving all the filter counts
class Generator(tf.keras.Model):
    def __init__(self, channels=3, method='transpose'):
        super(Generator, self).__init__()
        self.channels = channels
        self.method = method

        self.dense = tf.keras.layers.Dense(256 * 32 * 32, use_bias=False)

        self.reshape = tf.keras.layers.Reshape((32, 32, 256))

        if self.method == 'transpose':
            self.convT_1 = tf.keras.layers.Conv2DTranspose(128, (5, 5), padding='same', use_bias=False)
            self.convT_2 = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)
            self.convT_3 = tf.keras.layers.Conv2DTranspose(self.channels, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
        elif self.method == 'upsample':
            self.upsample2d = tf.keras.layers.UpSampling2D()
            self.conv_1 = tf.keras.layers.Conv2D(128, (3, 3), padding='same', use_bias=False)
            self.conv_2 = tf.keras.layers.Conv2D(64, (3, 3), padding='same', use_bias=False)
            self.conv_3 = tf.keras.layers.Conv2D(self.channels, (3, 3), padding='same', use_bias=False, activation='tanh')

        self.batch_norm_1 = tf.keras.layers.BatchNormalization()
        self.batch_norm_2 = tf.keras.layers.BatchNormalization()
        self.batch_norm_3 = tf.keras.layers.BatchNormalization()

        self.leakyrelu_1 = tf.keras.layers.LeakyReLU()
        self.leakyrelu_2 = tf.keras.layers.LeakyReLU()
        self.leakyrelu_3 = tf.keras.layers.LeakyReLU()

    def call(self, inputs, training=True):
        if self.method == 'transpose':
            x = self.dense(inputs)
            x = self.batch_norm_1(x, training)
            x = self.leakyrelu_1(x)

            x = self.reshape(x)

            x = self.convT_1(x)
            x = self.batch_norm_2(x, training)
            x = self.leakyrelu_2(x)

            x = self.convT_2(x)
            x = self.batch_norm_3(x, training)
            x = self.leakyrelu_3(x)

            return self.convT_3(x)

        elif self.method == 'upsample':
            # Replace Conv2DTranspose with UpSampling2D & Conv2D

            x = self.dense(inputs)
            x = self.batch_norm_1(x, training)
            x = self.leakyrelu_1(x)

            x = self.reshape(x)

            x = self.conv_1(x)
            x = self.batch_norm_2(x, training)
            x = self.leakyrelu_2(x)

            x = self.upsample2d(x)
            x = self.conv_2(x)
            x = self.batch_norm_3(x, training)
            x = self.leakyrelu_3(x)

            x = self.upsample2d(x)
            return self.conv_3(x)


class Critic(tf.keras.Model):
    def __init__(self):
        super(Critic, self).__init__()
        self.conv_1 = tf.keras.layers.Conv2D(16, (5, 5), strides=2, padding='same',input_shape=(32,32,3))#64 to 32# It looks like this guy wants an input of ndim=4
        self.conv_2 = tf.keras.layers.Conv2D(32, (5, 5), strides=2, padding='same')#128 to 64

        self.flatten = tf.keras.layers.Flatten()

        self.out = tf.keras.layers.Dense(1)#What is this 1. Should it be 10 cause of 10 classes?

        self.leakyrelu_1 = tf.keras.layers.LeakyReLU()
        self.leakyrelu_2 = tf.keras.layers.LeakyReLU()

        self.dropout_1 = tf.keras.layers.Dropout(0.3)
        self.dropout_2 = tf.keras.layers.Dropout(0.3)

        self.batch_norm_1 = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=True):
        x = self.conv_1(inputs)
        x = self.leakyrelu_1(x)
        x = self.dropout_1(x, training)

        x = self.conv_2(x)
        x = self.batch_norm_1(x, training)
        x = self.leakyrelu_2(x)
        x = self.dropout_2(x, training)

        x = self.flatten(x)

        return self.out(x)


if __name__ == "__main__":
    pass

В исходном коде был только генератор и крит c с одинаковыми значениями фильтра для генератора и крит c. Понятия не имею, что не так

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...