Выполняет ли вручную мультиклассовую логистическую регрессию softmax - правильна ли эта формула обновления весов? - PullRequest
0 голосов
/ 03 сентября 2018

Для проекта я пытаюсь закодировать мультиклассовую логистическую регрессию в Java. Я следовал этому уроку , чтобы попытаться понять его, но у меня возникают проблемы с программированием обучения модели, особенно с обновлением весов, в этом бите:

enter image description here

Меня действительно смущает размерность каждого компонента, например сигма-сумма (она многомерна?) И весовые коэффициенты.

Это код на данный момент (только самые важные биты)

/** the weight to learn */
private static double[][] weights;
static double sum[][] = new double[8][10];

public static class Instance {
    public int label;
    public double[] x;

    public Instance(int label, double[] x) {
        this.label = label;
        this.x = x;
    } //each instance contains a label int and a double array, x, containing the data itself
}

public void train(List<Instance> instances) {
     for (int n=0; n<ITERATIONS; n++) {    
        for (int j = 0; j < 8; j++) {   //resetting the sum as 0 for each iteration
                sum[j] = 0;
        }
        for (int i=0; i<instances.size(); i++) { //sets up the sigma function
            double[] x = instances.get(i).x;
            double[] predicted = classify(x);
            double[] label = onehotencoder(instances.get(i).label);

            double[] difference = new double[predicted.length];
            for (int a = 0; a < difference.length; a++) {
                difference[a] = label[a] - predicted[a];
            }

            double[] result = {0,0,0,0,0,0,0,0};
            for (int p = 0; p< result.length; p++) {
                for (int q = 0; q < result.length; q++) {
                    result[p] += x[p] * difference[q]; //dot product multiplying the difference by x (is this right??)
                }
            }

            for (int j = 0; j < 8; j++) { //doing the sigma summation 
                    sum[j] += result[j];
                }
            }

            for (int k=0; k<10; k++) { //updating weights with formula
                for (int j = 0; j < 8; j++) {
                weights[j][k] = weights[j][k] - rate * ((-1/instances.size()) * sum[j]);
            }
        }

    }
}

public double[] classify(double[] x) {
    double[] logit = new double[10];
    for (int j=0; j < 10; j++) {
        for (int i=0; i < 8;i++)  {
            logit[j] = weights[i][j] * x[i];
        }
    }
    double[] result = softmax(logit);
    return result;
}

softmax, standardisation functions etc. left out

Я почти уверен, что сумма с суммой и весом неверна, но я понятия не имею, что я делаю! Любая помощь будет очень признательна.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...