Обученный SVM выдает только 1,0 в результате, несмотря на ошибку обучения 0,0 - PullRequest
0 голосов
/ 17 февраля 2019

Я пытаюсь классифицировать набор данных !В этом наборе данных первый столбец является идеальным результатом, а остальные 20 столбцов являются входными данными.

Проблема, которая здесь возникает для меня, состоит в том, что SVM, обученный на наборе данных (в этом случае 80% используется для обучения), показывает ошибку обучения 0,0, но всегда прогнозирует 1,0 как результат.

Я разделил набор на две части, одну для обучения (80% данных) и 20% для классификации.Данные представляют собой объединение двух коротких временных рядов значений RSI (один период 2 и один период 14).

Почему SVM ведет себя таким образом?И могу ли я что-то сделать, чтобы избежать этого?Я думал, 0,0 ошибки обучения будет означать, что на тренировочном наборе SVM больше не допускает ошибок.Судя по результатам, это кажется ложным.

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.encog.Encog;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.svm.SVM;
import org.encog.ml.svm.training.SVMTrain;

public class SVMTest {

    public static void main(String[] args) {
        List<String> lines = readFile("/home/wens/mlDataSet.csv");
        double[][] trainingSetData = getInputData(lines, 0, lines.size()/10*8);
        double[][] trainingIdeal = getIdeal(lines, 0, lines.size()/10*8);
        MLDataSet trainingSet = new BasicMLDataSet(trainingSetData, trainingIdeal);
        double[][] classificationSetData = getInputData(lines, lines.size()/10*8, lines.size());
        double[][] classificationIdeal = getIdeal(lines, lines.size()/10*8, lines.size());
        MLDataSet classificationSet = new BasicMLDataSet(classificationSetData, classificationIdeal);

        SVM svm = new SVM(20,false);
        final SVMTrain train = new SVMTrain(svm, trainingSet);
        train.iteration();
        train.finishTraining();
        System.out.println("training error: " + train.getError());

        System.out.println("SVM Results:");
        for(MLDataPair pair: classificationSet ) {
            final MLData output = svm.compute(pair.getInput());
            System.out.println("actual: " + output.getData(0) + "\tideal=" + pair.getIdeal().getData(0));
        }

        Encog.getInstance().shutdown();
    }

    private static List<String> readFile(String filepath){
        List<String> res = new ArrayList<>();
        try {
            File f = new File(filepath);
            BufferedReader b = new BufferedReader(new FileReader(f));
            String readLine = "";
            while ((readLine = b.readLine()) != null) {
                res.add(readLine);
            }

        } catch (IOException e) {
            e.printStackTrace();
        }
        return res;
    }

    private static double[][] getInputData(List<String> lines, int start, int end){
        double[][] res = new double[end-start][20];
        int cnt = 0;
        for(int i=start; i<end; i++){
            String[] tmp = lines.get(i).split("\t");
            for(int j=1; j<tmp.length; j++){
                res[cnt][j-1] = Double.parseDouble(tmp[j]);
            }
            cnt++;
        }
        return res;
    }

    private static double[][] getIdeal(List<String> lines, int start, int end){
        double[][] res = new double[end-start][1];
        int cnt = 0;
        for(int i=start; i<end; i++){
            String[] tmp = lines.get(i).split("\t");
            res[cnt][0] = Double.parseDouble(tmp[0]);
            cnt++;
        }
        return res;
    }
}
...