Я работаю с Apache Spark MlLib версии 2.11 на Java.Мне нужно передать в RandomForestClassifier как категориальные, так и числовые функции (строки и числа).
Какой API лучше использовать в таком случае?Пример был бы очень полезен.
Редактировать
Я пытался использовать VectorIndexer, но он принимает только цифры, и я не мог понять, как интегрировать OneHotEncoder в него,Кроме того, мне не ясно, как определить, какие функции являются категориальными, а какие числовыми.Где мне нужно установить все возможные категории?
Вот код, который я пробовал:
StructType schema = DataTypes.createStructType(new StructField[] {
new StructField("label", DataTypes.StringType, false, Metadata.empty()),
new StructField("features", new ArrayType(DataTypes.StringType, false), false,
Metadata.empty()),
});
JavaRDD<Row> rowRDD = trainingData.map(record -> {
List<String> values = new ArrayList<>();
for (String field : fields) {
values.add(record.get(field));
}
return RowFactory.create(record.get(Constants.GROUND_TRUTH), values.toArray(new String[0]));
});
Dataset<Row> trainingDataDataframe = spark.createDataFrame(rowRDD, schema);
StringIndexerModel labelIndexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("indexedLabel")
.fit(trainingDataDataframe);
OneHotEncoder encoder = new OneHotEncoder()
.setInputCol("features")
.setOutputCol("featuresVec");
Dataset<Row> encoded = encoder.transform(trainingDataDataframe);
VectorIndexerModel featureIndexer = new VectorIndexer()
.setInputCol("featuresVec")
.setOutputCol("indexedFeatures")
.setMaxCategories(maxCategories)
.fit(encoded);
StringIndexerModel featureIndexer = new StringIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.fit(encoded);
RandomForestClassifier rf = new RandomForestClassifier();
.setNumTrees(numTrees);
.setFeatureSubsetStrategy(featureSubsetStrategy);
.setImpurity(impurity);
.setMaxDepth(maxDepth);
.setMaxBins(maxBins);
.setSeed(seed)
.setLabelCol("indexedLabel")
.setFeaturesCol("indexedFeatures");
IndexToString labelConverter = new IndexToString()
.setInputCol("prediction")
.setOutputCol("predictedLabel")
.setLabels(labelIndexer.labels());
Pipeline pipeline = new Pipeline()
.setStages(new PipelineStage[] {labelIndexer, featureIndexer, rf, labelConverter});
PipelineModel model = pipeline.fit(encoded);