Получение BinaryClassification FastTree FeatureNames для анализа PFI - PullRequest
2 голосов
/ 12 апреля 2019

Я построил простую модель BinaryClassification FastTree в ML.net 1.0.0, используя подмножество столбцов в моем trainingDataView. Теперь я хочу выполнить анализ PFI, но не могу выделить только столбцы / функции, используемые в модели, по сравнению со всеми столбцами в IDataView.

Я ссылаюсь на пример по по этой ссылке PFI для двоичной классификации.

var trainingDataView = mlContext.Data.LoadFromTextFile<FPPCNTKData>(TrainDataPath, hasHeader: false, separatorChar: ' ');

Var pipeline = mlContext.Transforms.Concatenate("Features",
                                                "mCalc_FPP_Legs_Range",
                                                "mCalc_FPP_Legs_Ticks",
                                                "mCalc_FPP_Legs_Bars",
                                                "mCalc_FPP_Legs_TMins",
                                                "mCalc_FPP_Diag_RangeBars",
                                                "mCalc_FPP_Diag_RangeTMins",
                                                "mCalc_FPP_Diag_TicksBars",
                                                "mCalc_FPP_Diag_TicksTMins",
                                                "mCalc_XD_XA_Mult_Ticks",
                                                "mCalc_AB_XA_Mult_Ticks",
                                                "mCalc_AD_XA_Mult_Ticks",
                                                "mCalc_BC_XA_Mult_Ticks",
                                                "mCalc_BC_AB_Mult_Ticks",
                                                "mCalc_CD_AB_Mult_Ticks",
                                                "mCalc_CD_BC_Mult_Ticks",
                                                "mCalc_CD_BD_Mult_Ticks")
     .Append(mlContext.BinaryClassification.Trainers.FastTree(labelColumnName: "mHiProfitOneHot", featureColumnName: "Features"));

var trainedModel = pipeline.Fit(trainingDataView);

Как вы можете видеть ниже, так как я собираю имена элементов из исходного trainingDataView, а не то, что использовалось в модели, элементы PFI помечены неправильно.

//// Compute the permutation metrics using the properly normalized data.
var linearPredictor = trainedModel.LastTransformer;
var transformedData = trainedModel.Transform(trainingDataView);
var permutationMetrics = mlContext.BinaryClassification.PermutationFeatureImportance(
                linearPredictor, transformedData, labelColumnName: "mHiProfitOneHot", permutationCount: 3);

// Now let's look at which features are most important to the model overall.
// Get the feature indices sorted by their impact on AUC.
var sortedIndices = permutationMetrics.Select((MetricStatistics, index) => new { index, metrics.AreaUnderRocCurve })
                .OrderByDescending(feature => Math.Abs(feature.AreaUnderRocCurve))
                .Select(feature => feature.index);

// Get the feature names from the training set
var featureNames =
    trainingDataView.Schema.AsEnumerable()
    .Select(column => column.Name) // Get the column names
    .Where(name => name != "mHiProfitOneHot") // Drop the Label
    .ToArray();


Console.WriteLine("Feature\tModel Weight\tChange in AUC\t95% Confidence in the Mean Change in AUC");
var auc = permutationMetrics.Select(x => x.AreaUnderRocCurve).ToArray();
foreach (int i in sortedIndices)
{
    Console.WriteLine("{0}\t{1:0.00}\t{2:G4}\t{3:G4}",
         featureNames[i],
         linearPredictor.Model.SubModel.TrainedTreeEnsemble.TreeWeights[i],
         auc[i].Mean,
         1.96 * auc[i].StandardError);
}

Можно ли извлечь подмножество имен элементов непосредственно из модели? Спасибо.

1 Ответ

0 голосов
/ 16 апреля 2019

Вы можете выполнить поиск в вашей модели (при условии, что это TransformerChain, как, кажется, в вашем случае), ища ColumnConcatenatingTransformer и получая имена входных столбцов.

string[] columnNames = (model
                    .FirstOrDefault(t => t is ColumnConcatenatingTransformer) as ColumnConcatenatingTransformer)
                    ?.Columns
                    ?.FirstOrDefault(c => c.outputColumnName == "Features")
                    .inputColumnNames;
Console.WriteLine(String.Join(", ", columnNames));
...