Проблема: предсказания pyspark.ml.regression.RandomForestRegressor по умолчанию являются дискретными выходными данными, соответствующими листу, который лучше соответствует входным данным.Он не интерполирует между двумя ближайшими листьями, что является моим желаемым поведением.
Вопрос: Как настроить pyspark.ml.regression.RandomForestRegressor для интерполяции выходных данных?
Я не могу найти такую опцию здесь: class pyspark.ml.regression.RandomForestRegressor
Воспроизведение проблемы: следуйте этому уроку: MLlib: Основное руководство
from pyspark.ml import Pipeline
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.feature import VectorIndexer
from pyspark.ml.evaluation import RegressionEvaluator
# Load and parse the data file, converting it to a DataFrame.
data = spark.read.format("libsvm").load("/FileStore/tables/sample_libsvm_data.txt")
# Automatically identify categorical features, and index them.
# Set maxCategories so features with > 4 distinct values are treated as continuous.
featureIndexer =\
VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])
# Train a RandomForest model.
rf = RandomForestRegressor(featuresCol="indexedFeatures")
# Chain indexer and forest in a Pipeline
pipeline = Pipeline(stages=[featureIndexer, rf])
# Train model. This also runs the indexer.
model = pipeline.fit(trainingData)
# Make predictions.
predictions = model.transform(testData)
# Select example rows to display.
predictions.select("prediction", "label", "features").show(10)
ВЫХОД
+----------+-----+--------------------+
|prediction|label| features|
+----------+-----+--------------------+
| 0.0| 0.0|(692,[95,96,97,12...|
| 0.0| 0.0|(692,[121,122,123...|
| 0.0| 0.0|(692,[122,123,148...|
| 0.0| 0.0|(692,[124,125,126...|
| 0.0| 0.0|(692,[124,125,126...|
| 0.0| 0.0|(692,[124,125,126...|
| 0.0| 0.0|(692,[124,125,126...|
| 0.2| 0.0|(692,[125,126,127...|
| 0.05| 0.0|(692,[126,127,128...|
| 0.05| 0.0|(692,[127,128,129...|
+----------+-----+--------------------+