Я пытаюсь понять, что матрица коэффициентов в логистической регрессии в pyspark (особенно лассо)? Это логистическая регрессия, поэтому я думаю, что веса должны быть просто в форме 1xn для n функций.
Также, как мне отобразить объекты и их коэффициенты вместе, чтобы увидеть, у какого объекта есть 0 коэффициентов.
Примечание: я использую MulticlassClassificationEvaluator, потому что в BinaryClassificationEvaluator нет возможности взвешенного отзыва. Это не проблема, я считаю?
Я получаю разреженную матрицу. Я делаю двоичную классификацию с 23 функциями, но получаю разреженную матрицу 3X23. Не должно быть 1X23.
from pyspark.ml.evaluation import BinaryClassificationEvaluator,MulticlassClassificationEvaluator
from pyspark.ml.classification import LogisticRegression
evaluator=MulticlassClassificationEvaluator(metricName="weightedRecall",predictionCol='prediction',labelCol='label')
lr = LogisticRegression(labelCol='label',
featuresCol="features",weightCol="classWeights")
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
paramGrid = ParamGridBuilder()\
.addGrid(lr.elasticNetParam,[1.0])\
.addGrid(lr.maxIter,[10])\
.addGrid(lr.regParam,[0.01, 0.5, 2.0]) \
.build()
cv = CrossValidator(estimator=lr, estimatorParamMaps=paramGrid,
evaluator=evaluator, numFolds=2)
%time cvModel = cv.fit(train_df)
predict_test_hyp=cvModel.transform(test_df)
coef=best.coefficientMatrix
Вывод для этого при преобразовании в плотную матрицу
DenseMatrix ([
[ 5.66693393e-01, 0.00000000e+00, -8.52316465e-09,
6.64542431e-03, 0.00000000e+00, 5.34390416e-02,
-4.51579298e-02, 0.00000000e+00, 0.00000000e+00,
-4.51579298e-02, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 3.16000659e-02, 9.72526723e-01,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, -4.70342863e-03,
0.00000000e+00, -1.75045505e-01],
[-2.41420676e-01, 1.67133402e-07, 8.30360611e-08,
-6.51168658e-03, 1.04331113e+00, 2.14565081e-01,
4.11803774e-01, 0.00000000e+00, 9.97492835e-08,
4.11803774e-01, 0.00000000e+00, 0.00000000e+00,
2.56259268e-01, 2.52040849e-02, -8.22050592e-01,
6.76655408e-01, 0.00000000e+00, 1.37646488e-01,
0.00000000e+00, 4.33575071e-02, 6.79627660e-03,
3.14889764e-01, 4.54933918e-01],
[-1.09990806e-01, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, -6.04906840e-01, -4.63173578e-01,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, -8.73277227e-02, 0.00000000e+00,
-1.98500866e-01, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, -3.00089689e-01, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00]])