Правильная реализация логистической регрессии - PullRequest
0 голосов
/ 24 апреля 2019

Я создаю приложение для Android, которое прогнозирует усталость на основе переменной Score. В качестве модели прогнозирования я использую логистическую регрессию, точнее, классы * @author tpeng * @author Matthieu Labas, которые я нашел здесь: https://github.com/tpeng/logistic-regression/blob/master/src/Logistic.java Проблема в том, что вероятности вывода, которые я получаю, кажутся очень неправильными.

Я попытался использовать исходный набор данных, предоставленный с кодом с 5 переменными предиктора, и результаты казались намного лучше. Я также пытался поиграть с темпами обучения и итерациями, но безуспешно.

    /**
     * Performs simple logistic regression.
     * User: tpeng
     * Date: 6/22/12
     * Time: 11:01 PM
     *
     * @author tpeng
     * @author Matthieu Labas
     */
    public class LogisticRegression {
    private static Context mContext;
        /** the learning rate */
        private double rate;

/** the weight to learn */
private double[] weights;

/** the number of iterations */
private int ITERATIONS = 1000;

public LogisticRegression(int n, Context context) {
    this.rate = 0.0001;
    weights = new double[n];
    this.mContext=context;
}

private static double sigmoid(double z) {
    return 1.0 / (1.0 + Math.exp(-z));
}

public void train(List<Instance> instances) {
    for (int n=0; n<ITERATIONS; n++) {
        double lik = 0.0;
        for (int i=0; i<instances.size(); i++) {
            int[] x = instances.get(i).x;
            double predicted = classify(x);
            int label = instances.get(i).label;
            for (int j=0; j<weights.length; j++) {
                weights[j] = weights[j] + rate * (label - predicted) * x[j];
            }
            // not necessary for learning
           // lik += label * Math.log(classify(x)) + (1-label) * Math.log(1- classify(x));
        }
  //      System.out.println("iteration: " + n + " " + Arrays.toString(weights) + " mle: " + lik);
    }
}

public double classify(int[] x) {
    double logit = .0;
    for (int i=0; i<weights.length;i++)  {
        logit += weights[i] * x[i];
    }
    return sigmoid(logit);
}

public static class Instance {
    public int label;
    public int[] x;

    public Instance(int label, int[] x) {
        this.label = label;
        this.x = x;
    }
}

public static List<Instance> readDataSet(String path) {
    List<Instance> dataset = new ArrayList<Instance>();
    Scanner scanner = null;
    AssetManager am = mContext.getAssets();
    try {
        InputStream is = am.open(path);
        scanner = new Scanner(new InputStreamReader(is));
        while(scanner.hasNextLine()) {
            String line = scanner.nextLine();
            if (line.startsWith("#")) {
                continue;
            }
            String[] columns = line.split("\\s+");

            // skip first column and last column is the label
            int i = 1;
            int[] data = new int[columns.length-2];
            for (i=1; i<columns.length-1; i++) {
                data[i-1] = Integer.parseInt(columns[i]);
            }
            int label = Integer.parseInt(columns[i]);
            Instance instance = new Instance(label, data);
            dataset.add(instance);
        }
    }catch (IOException e) {
        e.printStackTrace();
    } finally {
        if (scanner != null)
            scanner.close();
    }
    return dataset;
}


    }

Моя основная деятельность просто вызывает logisticRegression.classify с n = 1

Я использую самодельный набор данных, который содержит оценки от 0 до 10, и я ожидаю, что очень низкая вероятность усталости близка к 0 и очень высока, когда близка к 10. Независимо от баллов, с которыми я тестирую, я получаю вероятность около 0,5, когда близко к 0 и 0,7, когда близко к 10. Я дважды проверил с помощью внешнего инструмента логистической регрессии и того же набора данных, и результат, который я получаю, намного лучше. Я просмотрел код и сам не нашел никаких ошибок, но он не работает должным образом. Вот также часть моего набора данных, где второй столбец - оценка, а третий - метка:

TestData

1 10 1

2 9 1

3 10 1

4 9 1

5 5 0

6 4 0

7 3 0

8 2 0

9 10 1

10 10 1

11 7 1

12 6 1

13 6 0

14 5 0

15 10 1

16 10 1

17 9 1

18 2 0

19 1 0

...