Как вычислить вероятность многоклассового предсказания с помощью libsvm? - PullRequest
4 голосов
/ 04 мая 2010

Я использую libsvm , и документация заставляет меня поверить, что есть способ вывести предполагаемую вероятность точности классификации. Это так? И если да, то может ли кто-нибудь привести наглядный пример того, как это сделать в коде?

В настоящее время я использую библиотеки Java следующим образом

    SvmModel model = Svm.svm_train(problem, parameters);
    SvmNode x[] = getAnArrayOfSvmNodesForProblem();
    double predictedValue = Svm.svm_predict(model, x);

Ответы [ 2 ]

7 голосов
/ 04 мая 2010

Учитывая ваш фрагмент кода, я собираюсь предположить, что вы хотите использовать Java API, поставляемый с libSVM , а не более подробный, предоставляемый jlibsvm .

Чтобы включить прогнозирование с оценками вероятности, обучите модель с полем svm_parameter вероятность , установленным в 1 .Затем просто измените свой код так, чтобы он вызывал метод svm svm_predict_probability вместо svm_predict.

Изменяя ваш фрагмент, мы имеем:

parameters.probability = 1;
svm_model model = svm.svm_train(problem, parameters);

svm_node x[] = problem.x[0]; // let's try the first data pt in problem
double[] prob_estimates = new double[NUM_LABEL_CLASSES]; 
svm.svm_predict_probability(model, x, prob_estimates);

Стоит знать, что обучение с мультиклассовыми оценками вероятности может изменить предсказания , сделанные классификатором.Подробнее об этом см. Вопрос Расчет ближайшего совпадения для пары Среднее / Стддев с LibSVM .

1 голос
/ 03 февраля 2015

Принятый ответ работал как шарм. Обязательно установите probability = 1 во время тренировки.

Если вы пытаетесь отбросить прогноз, когда достоверность не достигнута с помощью порога, вот пример кода:

double confidenceScores[] = new double[model.nr_class];
svm.svm_predict_probability(model, svmVector, confidenceScores);

/*System.out.println("text="+ text);
for (int i = 0; i < model.nr_class; i++) {
    System.out.println("i=" + i + ", labelNum:" + model.label[i] + ", name=" + classLoadMap.get(model.label[i]) + ", score="+confidenceScores[i]);
}*/

//finding max confidence; 
int maxConfidenceIndex = 0;
double maxConfidence = confidenceScores[maxConfidenceIndex];
for (int i = 1; i < confidenceScores.length; i++) {
    if(confidenceScores[i] > maxConfidence){
        maxConfidenceIndex = i;
        maxConfidence = confidenceScores[i];
    }
}

double threshold = 0.3; // set this based data & no. of classes
int labelNum = model.label[maxConfidenceIndex];
// reverse map number to name
String targetClassLabel = classLoadMap.get(labelNum); 
LOG.info("classNumber:{}, className:{}; confidence:{}; for text:{}",
        labelNum, targetClassLabel, (maxConfidence), text);
if (maxConfidence < threshold ) {
    LOG.info("Not enough confidence; threshold={}", threshold);
    targetClassLabel = null;
}
return targetClassLabel;
...