Странное поведение нейтральной сети, написанной на простом JS - PullRequest
0 голосов
/ 06 января 2019

Под впечатлением от возможностей нейронных сетей я решил, что перед использованием любой библиотеки я хочу понять, как они работают. Поэтому я написал простое учебное приложение, в котором использовалась трехслойная сеть с двумя нейронами в каждом. Был холст 400х400. Учитывая координаты x, y мыши над холстом <0; 399>, он должен был дать в качестве результата координату / 400 <0; 1> (таким образом, для 100 300 он должен дать 0,25,0,75).

Тренировка выглядела разумно.

training

Но когда я переключаюсь в режим прогнозирования, сеть выдает одинаковый результат все время для каждой обучающей партии. Это дает одинаковые результаты независимо от того, что вход. training data Затем после дополнительного обучения выходной сигнал меняется, но он остается одинаковым для каждого входа.

Это написано на TypeScript. Вместо того, чтобы вставлять целую страницу веб-обучения, я просто создал учебный сценарий, чтобы вы могли более четко видеть, что происходит.

index.ts

let sigmoid: ActivationFunction = {
    func: (x: number) => (1 / (1 + Math.exp(-x))),
    derivative: (z: number) => {
        return sigmoid.func(z) * (1 - sigmoid.func(z));
    }
};

import Matrix from './matrix';

class NeutralNetwork {
    layers: Array<number>;
    weights: Matrix[];
    biases: Matrix[];
    activation_function: ActivationFunction;
    learning_rate: number;

    constructor(...layers: Array<number>) {
        this.layers = layers;
        this.activation_function = sigmoid;

        //Initialize neural network with random weigths and biases [-1;1]
        this.weights = [];
        for(let i=0; i<this.layers.length - 1; i++){
            this.weights.push(new Matrix(this.layers[i+1], this.layers[i]));
            this.weights[i].randomize();
        }
        this.biases = [];
        for(let i=1; i<this.layers.length; i++){
            this.biases.push(new Matrix(this.layers[i], 1));
            this.biases[i-1].randomize();
        }

        this.setActivationFunction();
        this.setLearningRate();
    }

    feedForward(originalInput: Array<number>): Array<number> {
        if(originalInput.length != this.layers[0]) throw new Error("corrupt input data");

        let input : Matrix = Matrix.createFromArray(originalInput);
        for(let i = 0; i < this.layers.length - 1; i++){
            let output = Matrix.multiply(this.weights[i], input);
            output.add(this.biases[i]);
            output.map(this.activation_function.func);
            input = output;
        }

        return input.toArray();
    }

    train(originalInput: Array<number>, originalTarget: Array<number>) {
        if(originalInput.length != this.layers[0]) throw new Error("corrupt training data");
        if(originalTarget.length != this.layers[this.layers.length - 1]) throw new Error("corrupt training data");

        let outputs : Matrix[] = [];
        let input : Matrix = Matrix.createFromArray(originalInput);
        for(let i = 0; i < this.layers.length - 1; i++){
            let output = Matrix.multiply(this.weights[i], input);
            output.add(this.biases[i]);
            output.map(this.activation_function.func);
            input = output;
            outputs.push(output);
        }

        let target = Matrix.createFromArray(originalTarget);
        let errors = Matrix.subtract(target, outputs[this.layers.length - 2]);


        for(let i = this.layers.length - 2; i>=0; i--){
            let gradients = Matrix.map(outputs[i], this.activation_function.derivative);
            gradients.multiply(errors);
            gradients.multiply(this.learning_rate);

            let outputsOfLayerBeforeTransposed = Matrix.transpose(i > 0 ? outputs[i-1] : Matrix.createFromArray(originalInput));
            let deltas = Matrix.multiply(gradients, outputsOfLayerBeforeTransposed);

            this.weights[i].add(deltas);
            this.biases[i].add(gradients);

            let weightsTransposed = Matrix.transpose(this.weights[i]);
            errors = Matrix.multiply(weightsTransposed, errors);
        }

        return outputs[outputs.length - 1].toArray();

    }



    setActivationFunction(activationFunction = sigmoid) {
        this.activation_function = activationFunction;
    }
    setLearningRate(learning_rate = 0.1) {
        this.learning_rate = learning_rate;
    }
}


interface ActivationFunction {
    func(x: number): number;
    derivative(x: number): number;
}

export = NeutralNetwork;

training.ts

let NN = require('./index');

let n = new NN(2,2,2);

let data = generateTrainingData();
data.forEach(d => n.train(d.i, d.o));

//check how well is it trained
let index = 0
let t = setInterval(()=>{
    let pred = n.feedForward(data[index].i);
    console.log(`PREDICTED - ${pred} EXPECTED = ${data[index].o} COST - ${Math.pow(pred[0]-data[index].o[0],2)+Math.pow(pred[1]-data[index].o[1],2)}`)
    if(index++ == 1000) clearInterval(t);
}, 500);

function generateTrainingData(){
    let data = [];
    for(let i=0;i<1000;i++){
        let x = Math.floor(Math.random() * 400);
        let y = Math.floor(Math.random() * 400);
        data.push({
            i : [x,y],
            o : [x/400, y/400]
        })
    }

    return data;
}

matrix.ts

export default class Matrix {
    rows;
    columns;
    data: Array<Array<number>>;

    constructor(rows, columns) {
        this.rows = rows;
        this.columns = columns;
        this.data = new Array(this.rows).fill().map(() => Array(this.columns).fill(0));
    }

    static map(matrix, f) : Matrix{
        let m = new Matrix(matrix.rows, matrix.columns);
        m.map((v,i,j) => f(matrix.data[i][j], i, j));
        return m;
    }

    map(f) {
        for (let i = 0; i < this.rows; i++) {
            for (let j = 0; j < this.columns; j++) {
                this.data[i][j] = f(this.data[i][j], i, j);
            }
        }
    }

    randomize() {
        this.map(() => Math.random() * 2 - 1);
    }

    add(n) {
        if (n instanceof Matrix) {
            if (this.rows !== n.rows || this.columns !== n.columns) {
                throw new Error('Size of both matrices must match!');
            }
            return this.map((v, i, j) => v + n.data[i][j]);
        } else {
            return this.map(v => v + n);
        }
    }



    static subtract(a, b) : Matrix{
        if (a.rows !== b.rows || a.columns !== b.columns) {
            throw new Error('Size of both matrices must match!');
        }

        let m = new Matrix(a.rows, a.columns);
        m.map((_, i, j) => a.data[i][j] - b.data[i][j]);
        return m;
    }

    static multiply(a, b) {

        if (a.columns !== b.rows) {
            throw new Error('a.columns !== b.rows');
        }

        let m = new Matrix(a.rows, b.columns)
        m.map((_, i, j) => {
            let sum = 0;
            for (let k = 0; k < a.cols; k++) {
                sum += a.data[i][k] * b.data[k][j];
            }
            return sum;
        });

        return m;
    }
    multiply(n) {
        if (n instanceof Matrix) {
            if (this.rows !== n.rows || this.columns !== n.columns) {
                throw new Error('Size of both matrices must match!');
            }
            return this.map((v, i, j) => v * n.data[i][j]);
        } else {
            return this.map(v => v * n);
        }
    }
    toArray() {
        let arr = [];
        for (let i = 0; i < this.rows; i++) {
            for (let j = 0; j < this.columns; j++) {
                arr.push(this.data[i][j]);
            }
        }
        return arr;
    }

    static transpose(matrix) : Matrix {
        let m = new Matrix(matrix.columns, matrix.rows)
        m.map((_, i, j) => matrix.data[j][i]);
        return m;
    }

    static createFromArray(arr): Matrix {
        let m = new Matrix(arr.length, 1);
        m.map((v, i) => arr[i]);
        return m;
    }
}

Я не совсем уверен, в чем причина этого. Я пытался отладить это уже несколько дней, но я думаю, что мой недостаток опыта не позволяет мне увидеть проблему здесь. Большое спасибо за вашу помощь.

1 Ответ

0 голосов
/ 06 января 2019

Ошибка в Matrix.multiply методе класса. Это должно быть a.columns, а не a.cols. Из-за этого gradients и deltas не обновляются должным образом.

...