Я создаю приложение для Android, которое прогнозирует усталость на основе переменной Score. В качестве модели прогнозирования я использую логистическую регрессию, точнее, классы * @author tpeng * @author Matthieu Labas, которые я нашел здесь: https://github.com/tpeng/logistic-regression/blob/master/src/Logistic.java
Проблема в том, что вероятности вывода, которые я получаю, кажутся очень неправильными.
Я попытался использовать исходный набор данных, предоставленный с кодом с 5 переменными предиктора, и результаты казались намного лучше. Я также пытался поиграть с темпами обучения и итерациями, но безуспешно.
/**
* Performs simple logistic regression.
* User: tpeng
* Date: 6/22/12
* Time: 11:01 PM
*
* @author tpeng
* @author Matthieu Labas
*/
public class LogisticRegression {
private static Context mContext;
/** the learning rate */
private double rate;
/** the weight to learn */
private double[] weights;
/** the number of iterations */
private int ITERATIONS = 1000;
public LogisticRegression(int n, Context context) {
this.rate = 0.0001;
weights = new double[n];
this.mContext=context;
}
private static double sigmoid(double z) {
return 1.0 / (1.0 + Math.exp(-z));
}
public void train(List<Instance> instances) {
for (int n=0; n<ITERATIONS; n++) {
double lik = 0.0;
for (int i=0; i<instances.size(); i++) {
int[] x = instances.get(i).x;
double predicted = classify(x);
int label = instances.get(i).label;
for (int j=0; j<weights.length; j++) {
weights[j] = weights[j] + rate * (label - predicted) * x[j];
}
// not necessary for learning
// lik += label * Math.log(classify(x)) + (1-label) * Math.log(1- classify(x));
}
// System.out.println("iteration: " + n + " " + Arrays.toString(weights) + " mle: " + lik);
}
}
public double classify(int[] x) {
double logit = .0;
for (int i=0; i<weights.length;i++) {
logit += weights[i] * x[i];
}
return sigmoid(logit);
}
public static class Instance {
public int label;
public int[] x;
public Instance(int label, int[] x) {
this.label = label;
this.x = x;
}
}
public static List<Instance> readDataSet(String path) {
List<Instance> dataset = new ArrayList<Instance>();
Scanner scanner = null;
AssetManager am = mContext.getAssets();
try {
InputStream is = am.open(path);
scanner = new Scanner(new InputStreamReader(is));
while(scanner.hasNextLine()) {
String line = scanner.nextLine();
if (line.startsWith("#")) {
continue;
}
String[] columns = line.split("\\s+");
// skip first column and last column is the label
int i = 1;
int[] data = new int[columns.length-2];
for (i=1; i<columns.length-1; i++) {
data[i-1] = Integer.parseInt(columns[i]);
}
int label = Integer.parseInt(columns[i]);
Instance instance = new Instance(label, data);
dataset.add(instance);
}
}catch (IOException e) {
e.printStackTrace();
} finally {
if (scanner != null)
scanner.close();
}
return dataset;
}
}
Моя основная деятельность просто вызывает logisticRegression.classify с n = 1
Я использую самодельный набор данных, который содержит оценки от 0 до 10, и я ожидаю, что очень низкая вероятность усталости близка к 0 и очень высока, когда близка к 10. Независимо от баллов, с которыми я тестирую, я получаю вероятность около 0,5, когда близко к 0 и 0,7, когда близко к 10. Я дважды проверил с помощью внешнего инструмента логистической регрессии и того же набора данных, и результат, который я получаю, намного лучше. Я просмотрел код и сам не нашел никаких ошибок, но он не работает должным образом. Вот также часть моего набора данных, где второй столбец - оценка, а третий - метка:
TestData
1 10 1
2 9 1
3 10 1
4 9 1
5 5 0
6 4 0
7 3 0
8 2 0
9 10 1
10 10 1
11 7 1
12 6 1
13 6 0
14 5 0
15 10 1
16 10 1
17 9 1
18 2 0
19 1 0