Уверенность мультиклассовой классификации с ML.Net - PullRequest
0 голосов
/ 27 сентября 2018

Я нашел идеальное введение в ML.NET: https://www.codeproject.com/Articles/1249611/Machine-Learning-with-ML-Net-and-Csharp-VB-Net. Это помогло мне решить некоторые вопросы с ML.NET.

Но один из них все еще актуален:

Когда я отправляю какой-либо текст в детектор языка (пример LanguageDetection), я всегда получаю результат.Даже если классификация не уверенна для очень короткого фрагмента текста.Могу ли я получить информацию об уверенности в мультиклассовой классификации?Или вероятность принадлежности к какому-либо классу, чтобы использовать его во втором проходе алгоритма, который использует языки соседних предложений?

1 Ответ

0 голосов
/ 29 сентября 2018

Согласно подсказке @ Jon, я изменил исходный пример из CodeProject.Этот код можно найти по следующей ссылке: https://github.com/sotnyk/LanguageDetector/tree/Code-for-stackoverflow-52536943

Основным является (как предложил Джон) добавление поля:

public float[] Score;

в класс ClassPrediction.

Если это поле существует, мы получили вероятности / достоверности мультиклассовой классификации для каждого класса.

Но у нас есть другая трудность с оригинальным примером.Он использует значения с плавающей запятой в качестве метки категории.Но это не показатели в массиве очков.Чтобы сопоставить индексы оценок с категориями, мы должны использовать метод TryGetScoreLabelNames:

if (!model.TryGetScoreLabelNames(out var scoreClassNames))
    throw new Exception("Can't get score classes");

Но этот метод не работает с метками классов в качестве значений с плавающей запятой.Поэтому я изменил исходные файлы .tsv и поля ClassificationData.LanguageClass и ClassPrediction.Class, чтобы использовать строковые метки в качестве имен классов.

Дополнительные изменения, которые не упомянуты непосредственно в теме вопроса:

  • Обновленная версия nuget-пакетов.
  • Мне интересно работать с классификатором lightGBM (он показывает лучшее качество для меня).Но в текущей версии его nuget-пакет содержит ошибку для не-NetCore приложений.Итак, я изменил платформу примеров на NetCore20 / Standard.
  • Не комментируемая модель использует классификатор lightGBM.

Баллы для каждого языка, напечатанного в приложении с именем Prediction.Теперь эта часть кода выглядит следующим образом:

internal static async Task<PredictionModel<ClassificationData, ClassPrediction>> PredictAsync(
    string modelPath,
    IEnumerable<ClassificationData> predicts = null,
    PredictionModel<ClassificationData, ClassPrediction> model = null)
{
    if (model == null)
    {
        new LightGbmArguments();
        model = await PredictionModel.ReadAsync<ClassificationData, ClassPrediction>(modelPath);
    }

    if (predicts == null) // do we have input to predict a result?
        return model;

    // Use the model to predict the positive or negative sentiment of the data.
    IEnumerable<ClassPrediction> predictions = model.Predict(predicts);

    Console.WriteLine();
    Console.WriteLine("Classification Predictions");
    Console.WriteLine("--------------------------");

    // Builds pairs of (sentiment, prediction)
    IEnumerable<(ClassificationData sentiment, ClassPrediction prediction)> sentimentsAndPredictions =
        predicts.Zip(predictions, (sentiment, prediction) => (sentiment, prediction));

    if (!model.TryGetScoreLabelNames(out var scoreClassNames))
        throw new Exception("Can't get score classes");

    foreach (var (sentiment, prediction) in sentimentsAndPredictions)
    {
        string textDisplay = sentiment.Text;

        if (textDisplay.Length > 80)
            textDisplay = textDisplay.Substring(0, 75) + "...";

        string predictedClass = prediction.Class;

        Console.WriteLine("Prediction: {0}-{1} | Test: '{2}', Scores:",
            prediction.Class, predictedClass, textDisplay);
        for(var l = 0; l < prediction.Score.Length; ++l)
        {
            Console.Write($"  {l}({scoreClassNames[l]})={prediction.Score[l]}");
        }
        Console.WriteLine();
        Console.WriteLine();
    }
    Console.WriteLine();

    return model;
}

}

...