При обратном распространении нейронной сети все веса, кроме последнего, будут одинаковыми - PullRequest
0 голосов
/ 19 марта 2020

Я делаю авто-кодировщик, который сжимает изображения. Входной уровень имеет 6144 узла (64 * 32 * 3), следующий уровень имеет 2048 узлов, скрытое пространство имеет 1024 узла, а часть декодера имеет обратную структуру (2048, 6144). Это просто простая сеть прямой связи. Я пытаюсь использовать обратное распространение для обучения сети. Последняя группа весов (те, которые подключены к выходному слою) обновляются нормально, но другие весы нет. Они меняются, но все они меняются в унисон. Я попытался установить веса на случайные начальные значения, изменить размер скрытого пространства (сначала это было 64 узла) и настроить другие вещи, такие как скорость обучения и моя функция активации, но ничего не помогло. Это заставляет мой авто-кодер просто выводить среднее изображение, а не то, что его вводит. Как я могу это исправить? Вот мой код:

namespace MinecraftSkinGenerator{
class Program
{
    //--------------------PUBLIC VARIABLES--------------------
    //Network Path
    public static string NetworkValuesPath = @"D:\minecraft-skins\Test.NeuralNetwork";
    //Skin Folder for input images
    public static string SkinFolder = @"D:\minecraft-skins\skins\Small_Skins";
    //Skin Folder for output images
    public static string OutputFolder = @"D:\minecraft-skins\Output";
    //Neural Network
    public static Network network;
    public static int LatentSize = 1024;
    public static float LearningRate = 0.01f;
    //For calculating average loss
    public static float TotalLoss = 0;
    public static int TotalSkinsSeen = 0;

    static void Main(string[] args)
    {
        //-----------------NETWORK SETUP---------------
        //Setup and Reset the network
        network = new Network();
        network.Reset();
        //Load from the file
        network.Load(NetworkValuesPath);
        //-----------------START LEARNING----------------
        SkinCopyAttempt(0);
        Console.ReadLine();
    }
    public static void SkinCopyAttempt(int id)
    {
        //----------------------------------GRAB THE INPUT----------------------
        //Load the skin
        Image SkinImage = Image.FromFile(Path.Combine(SkinFolder, id + ".jpg"));
        //Convert it to Bitmap
        Bitmap SkinBitmap = (Bitmap) SkinImage;
        //Convert it to array
        float[] Skin = BitmapToArray(SkinBitmap);
        //Dispose the image and bitmap
        SkinImage.Dispose();
        SkinBitmap.Dispose();
        //----------------------------------ENCODER-----------------------------
        //First Layer
        float[] EncoderLayer = WheightCalculation(Skin, network.EncoderWheight0);
        //Latent Space
        float[] LatentSpace = WheightCalculation(EncoderLayer, network.EncoderWheight1);
        //----------------------------------DECODER-----------------------------
        //First Layer
        float[] DecoderLayer = WheightCalculation(LatentSpace, network.DecoderWheight0);
        //Output
        float[] Output = WheightCalculation(DecoderLayer, network.DecoderWheight1);
        //----------------------------------CALCULATE LOSS---------------------
        //loss is the distance between input and output
        float loss = 0;
        //(this will be used in "tweak wheights')
        float[] outputDirection = new float[64 * 32 * 3];
        //Run throught every byte of skin
        for (int i = 0; i < Skin.Length; i++)
        {
            //Calculate differnce between input and output, srqaure it, then add that to loss
            loss += (float) Math.Pow(Skin[i] - Output[i], 2);
            //(this will be used in "tweak wheights') Does this output need to be larger or smaller?
            outputDirection[i] = (float)Math.Atan(Skin[i] - Output[i]);
        }
        //sqrt loss
        loss = (float) Math.Sqrt(loss);
        //Convert to percentage out of maximum loss: 19987.8363011
        loss = loss / 199.878363011f;
        //Calculate average loss
        TotalLoss += loss;
        TotalSkinsSeen++;
        float avLoss = TotalLoss / TotalSkinsSeen;
        //Write the loss to the console
        Console.WriteLine("Skin " + id + " Loss: " + loss + "%" + " Average Loss: " + avLoss + "%");
        //----------------------------------TWEAK WHEIGHTS----------------------
        //DecoderWheights1
        float[] DecoderDir = TweakWheights(DecoderLayer, network.DecoderWheight1, outputDirection);
        //DecoderWheights0
        float[] LatentDir = TweakWheights(LatentSpace, network.DecoderWheight0, DecoderDir);
        //EncoderWheights1
        float[] EncoderDir = TweakWheights(EncoderLayer, network.EncoderWheight1, LatentDir);
        //EncoderWheights0
        TweakWheights(Skin, network.EncoderWheight0, EncoderDir);
        //----------------------------------REPEAT------------------------------
        //Is this the 100th skin?
        if (id%100 == 0)
        {

            //Output skin
            FileStream fs = File.Create(Path.Combine(OutputFolder, id + "_1Input.jpg"));
            Bitmap bitI = ArrayToBitmap(Skin, 64, 32, true);
            bitI.Save(fs, System.Drawing.Imaging.ImageFormat.Png);
            fs.Close();
            bitI.Dispose();

            //Output encoder
            fs = File.Create(Path.Combine(OutputFolder, id + "_2Encoder.jpg"));
            Bitmap bitE = ArrayToBitmap(EncoderLayer, 64, 32, false);
            bitE.Save(fs, System.Drawing.Imaging.ImageFormat.Png);
            fs.Close();
            bitE.Dispose();
            //Output latent
            fs = File.Create(Path.Combine(OutputFolder, id + "_3Latent.jpg"));
            Bitmap bitL = ArrayToBitmap(LatentSpace, (int)Math.Sqrt(LatentSize), (int)Math.Sqrt(LatentSize), false);
            bitL.Save(fs, System.Drawing.Imaging.ImageFormat.Png);
            fs.Close();
            bitL.Dispose();
            //Output decoder
            fs = File.Create(Path.Combine(OutputFolder, id + "_4Decoder.jpg"));
            Bitmap bitD = ArrayToBitmap(DecoderLayer, 64, 32, false);
            bitD.Save(fs, System.Drawing.Imaging.ImageFormat.Png);
            fs.Close();
            bitD.Dispose();

            //Output skin
            fs = File.Create(Path.Combine(OutputFolder, id + "_5Output.jpg"));
            Bitmap bitS = ArrayToBitmap(Output, 64, 32, true);
            bitS.Save(fs, System.Drawing.Imaging.ImageFormat.Png);
            fs.Close();
            bitS.Dispose();
            //Save the network
            network.Save(NetworkValuesPath);
        }
        if(id == 2326)
            SkinCopyAttempt(0);
        else
            SkinCopyAttempt(id+1);
    }
    public static float[] BitmapToArray(Bitmap bit)
    {
        //make array
        float[] floats = new float[bit.Width * bit.Height * 3];
        //used to tell what byte we are currently setting
        int b = 0;
        //loop through every pixel
        for (int x = 0; x < bit.Width; x++)
        {
            for (int y = 0; y < bit.Height; y++)
            {
                //Copy each color to the array
                floats[b] = bit.GetPixel(x, y).R;
                floats[b + 1] = bit.GetPixel(x, y).G;
                floats[b + 2] = bit.GetPixel(x, y).B;
                b += 3;
            }
        }
        //reurn
        return floats;
    }
    public static Bitmap ArrayToBitmap(float[] floats, int w, int h, bool c)
    {
        //make map
        Bitmap bit = new Bitmap(w, h);
        //used to tell what byte we are currently reading
        int b = 0;
        //loop through every pixel
        for (int x = 0; x < bit.Width; x++)
        {
            for (int y = 0; y < bit.Height; y++)
            {
                //Copy each color to the array
                if (c)
                {
                    //Set the pixel. Wierd code is to clamp the float between 0 and 255 and then convert it to an int
                    bit.SetPixel(x, y, Color.FromArgb(255, (int)Math.Max(Math.Min(floats[b], 255), 0), (int)Math.Max(Math.Min(floats[b + 1], 255), 0), (int)Math.Max(Math.Min(floats[b + 2], 255), 0)));
                    b += 3;
                }
                else
                {
                    //Set the pixel. Wierd code is to clamp the float between 0 and 255 and then convert it to an int
                    bit.SetPixel(x, y, Color.FromArgb(255, (int)Math.Max(Math.Min(floats[b], 255), 0), (int)Math.Max(Math.Min(floats[b], 255), 0), (int)Math.Max(Math.Min(floats[b], 255), 0)));
                    b ++;
                }
            }
        }
        //reurn
        return bit;
    }
    public static float[] WheightCalculation(float[] input, float[][] wheights)
    {
        //Make new array
        float[] newLayer = new float[wheights[0].Length];
        //Run through all elements of the new layer array
        for (int i = 0; i < newLayer.Length; i++)
        {
            //Run through all of the inputs
            for (int i2 = 0; i2 < input.Length; i2++)
            {
                newLayer[i] += input[i2] * wheights[i2][i];
            }
            //Activation Function (RelU)
            newLayer[i] = (float)((Math.Atan((double)newLayer[i] / 1000f) + (Math.PI / 2)) * 81.1690209769f);
        }
        //return the layer
        return newLayer;
    }
    public static float[] TweakWheights(float[] input, float[][] wheights, float[] direction)
    {
        //Setup output. The output is the tweaking direction for the previous layer
        float[] newDirection = new float[wheights.Length];
        //loop through every input
        for (int i = 0; i < wheights.Length; i++)
        {
            //use this to keep track of what direction the input needs to go
            newDirection[i] = 0;
            //loop through every direction
            for (int i2 = 0; i2 < direction.Length; i2++)
            {
                //tweak wheight by the amount it effects network aka direction
                wheights[i][i2] += LearningRate * (float)Math.Atan(direction[i2] * (input[i] + 0.00000001f));
                //what direction does this input need to go?
                newDirection[i] += Math.Abs(input[i] * wheights[i][i2]) * direction[i2];
                //Does this input need to go up or down?
                /*
                //Does this wheight need to go up or down?
                if (direction[i2])
                {
                    //Up
                    wheights[i][i2] += LearningRate;
                    //Is this wheight positize? (>128) Or negative? (<128)
                    if (wheights[i][i2] > 128)
                    {
                        //It would be better if the input was larger
                        dir += (int)Math.Round(input[i] * wheights[i][i2]);
                    }
                    else if (wheights[i][i2] < 128)
                    {
                        //It would be better if the input was smaller
                        dir -= (int)Math.Round(input[i] * wheights[i][i2]);
                    }
                }
                else
                {
                    //down
                    wheights[i][i2] -= LearningRate;
                    //Is this wheight positize? (>=128)
                    if (wheights[i][i2] >= 128)
                    {
                        //It would be better if the input was smaller
                        dir -= (int)Math.Round(input[i] * wheights[i][i2]);
                    }
                    else if (wheights[i][i2] < 128)
                    {
                        //It would be better if the input was larger
                        dir += (int)Math.Round(input[i] * wheights[i][i2]);
                    }
                }
                */
            }
            //Would it be better if this input was larger or smaller?
            newDirection[i] = (float)Math.Atan(newDirection[i]);
            //Console.Write(newDirection[i]);
        }
        //Return new direction
        return newDirection;
    }
}
class Network
{
    public float[][] EncoderWheight1;
    public float[][] EncoderWheight0;
    //public float[] EncoderBias;
    public float[][] DecoderWheight1;
    public float[][] DecoderWheight0;
    //public float[] DecoderBias;

    public static byte[] FloatToByte(float[][] f)
    {
        byte[] b = new byte[f.Length * f[0].Length * 4];
        for (int i = 0; i < f.Length; i++)
        {
            Buffer.BlockCopy(f[i], 0, b, i * f[i].Length * 4, f[i].Length * 4);
        }
        return b;
    }
    public static float[][] ByteToFloat(byte[] b, int x, int y)
    {
        float[][] f = new float[x][];
        for (int i = 0; i < x; i++)
        {
            f[i] = new float[y];
            Buffer.BlockCopy(b, y*i*4, f[i], 0, y*4);
        }
        return f;
    }
    public void Random()
    {
        /*
        //1. EncoderBias
        EncoderBias = new float[(64 * 32)];
        //2. DecoderBias
        DecoderBias = new byte[(64 * 32)];
        */
        //3. EncoderWheight0
        var random = new Random();
        EncoderWheight0 = new float[(64 * 32 * 3)][];
        for (int i = 0; i < (64 * 32 * 3); i++)
        {
            EncoderWheight0[i] = new float[(64 * 32)];
            for (int i2 = 0; i2 < 64 * 32; i2++)
            {
                EncoderWheight0[i][i2] = (float) random.Next(-1000, 1000) / 10000f;
            }
        }
        //4. DecoderWheight0
        DecoderWheight0 = new float[1024][];
        for (int i = 0; i < 1024; i++)
        {
            DecoderWheight0[i] = new float[(64 * 32)];
            for (int i2 = 0; i2 < 64 * 32; i2++)
            {
                DecoderWheight0[i][i2] = random.Next(-1000, 1000) / 10000f;
            }
        }
        //5. EncoderWheight1
        EncoderWheight1 = new float[(64 * 32)][];
        for (int i = 0; i < (64 * 32); i++)
        {
            EncoderWheight1[i] = new float[1024];
            for (int i2 = 0; i2 < 1024; i2++)
            {
                EncoderWheight1[i][i2] = random.Next(-1000, 1000) / 10000f;
            }
        }
        //6. DecoderWheight1
        DecoderWheight1 = new float[(64 * 32)][];
        for (int i = 0; i < 64 * 32; i++)
        {
            DecoderWheight1[i] = new float[(64 * 32 * 3)];
            for (int i2 = 0; i2 < (64 * 32 * 3); i2++)
            {
                DecoderWheight1[i][i2] = random.Next(-1000, 1000) / 10000f;
            }
        }
    }
    public void Reset()
    {
        /*
        //1. EncoderBias
        EncoderBias = new byte[(64 * 32)];
        //2. DecoderBias
        DecoderBias = new byte[(64 * 32)];
        */
        //3. EncoderWheight0
        EncoderWheight0 = new float[(64 * 32 * 3)][];
        for (int i = 0; i < (64 * 32 * 3); i++)
        {
            EncoderWheight0[i] = new float[(64 * 32)];
            for (int i2 = 0; i2 < 64 * 32; i2++)
            {
                EncoderWheight0[i][i2] = 0.0001f;
            }
        }
        //4. DecoderWheight0
        DecoderWheight0 = new float[1024][];
        for (int i = 0; i < 1024; i++)
        {
            DecoderWheight0[i] = new float[(64 * 32)];
            for (int i2 = 0; i2 < 64 * 32; i2++)
            {
                DecoderWheight0[i][i2] = 0.0001f;
            }
        }
        //5. EncoderWheight1
        EncoderWheight1 = new float[(64 * 32)][];
        for (int i = 0; i < (64 * 32); i++)
        {
            EncoderWheight1[i] = new float[1024];
            for (int i2 = 0; i2 < 1024; i2++)
            {
                EncoderWheight1[i][i2] = 0.0001f;
            }
        }
        //6. DecoderWheight1
        DecoderWheight1 = new float[(64*32)][];
        for (int i = 0; i < 64 * 32; i++)
        {
            DecoderWheight1[i] = new float[(64 * 32 * 3)];
            for (int i2 = 0; i2 < (64*32*3); i2++)
            {
                DecoderWheight1[i][i2] = 0.0001f;
            }
        }
    }
    public void Save(string path)
    {
        //Write the data
        FileStream fs = File.Open(path, FileMode.Open);
        /*
        //Encoder Bias
        byte[] ByteEncoderBias = new byte[EncoderBias.Length*4];
        Buffer.BlockCopy(EncoderBias, 0, ByteEncoderBias, 0, ByteEncoderBias.Length);
        fs.Write(ByteEncoderBias, 0, EncoderBias.Length);
        //Decoder Bias
        fs.Write(DecoderBias, 0, DecoderBias.Length);
        */
        //Encoder Wheight0
        fs.Write(FloatToByte(EncoderWheight0), 0, 64 * 32 * 3 * 64 * 32 * 4);
        //Decoder Wheight0
        fs.Write(FloatToByte(DecoderWheight0), 0, 1024 * 64 * 32 * 4);
        //Encoder Wheight1
        fs.Write(FloatToByte(EncoderWheight1), 0, 1024 * 32 * 64 * 4);
        //Decoder Wheight1
        fs.Write(FloatToByte(DecoderWheight1), 0, 64 * 32 * 3 * 64 * 32 * 4);
        //Close
        fs.Close();
    }
    public void Load(string path)
    {
        //Load the data
        FileStream fs = File.Open(path, FileMode.Open);
        /*
        //Encoder Bias
        fs.Read(EncoderBias, 0, EncoderBias.Length);
        //Decoder Bias
        fs.Read(DecoderBias, 0, DecoderBias.Length);
        */
        //------Encoder Wheight0
        //read from file
        byte[] b = new byte[64 * 32 * 3 * 64 * 32 * 4];
        fs.Read(b, 0, 64 * 32 * 3 * 64 * 32 * 4);
        //convert to float
        EncoderWheight0 = ByteToFloat(b, 64 * 32 * 3, 64 * 32);
        //------Decoder Wheight0
        //read from file
        b = new byte[1024 * 64 * 32 * 4];
        fs.Read(b, 0, 1024 * 64 * 32 * 4);
        //convert to float
        DecoderWheight0 = ByteToFloat(b, 1024, 64 * 32);
        //------Encoder Wheight1
        //read from file
        b = new byte[1024 * 64 * 32 * 4];
        fs.Read(b, 0, 1024 * 64 * 32 * 4);
        //convert to float
        EncoderWheight1 = ByteToFloat(b, 64 * 32, 1024);
        //------Decoder Wheight1
        //read from file
        b = new byte[64 * 32 * 3 * 64 * 32 * 4];
        fs.Read(b, 0, 64 * 32 * 3 * 64 * 32 * 4);
        //convert to float
        DecoderWheight1 = ByteToFloat(b, 64 * 32, 64 * 32 * 3);
        //------Close
        fs.Close();
    }
}

}

...