Я строю дерево решений в Писпарке. Поэтому использовали StringIndexer для преобразования строкового атрибута в цифру c для дальнейших вычислений
from pyspark.ml.feature import StringIndexer
indexer = StringIndexer(inputCol="SKL", outputCol="IndexedLabel")
Позже, когда я использую IndexTo String для возврата столбца в формат String, как показано ниже
from pyspark.ml.feature import IndexToString
PredConverter = IndexToString(inputCol='prediction', outputCol='predictionStr',labels=indexer.SKL)
predictions = PredConverter.transform(dtc_preds)
predictions.show()
Я получаю сообщение об ошибке: AttributeError: у объекта 'StringIndexer' нет атрибута 'SKL'
Хотя при проверке имен входных столбцов в объекте класса StringIndexer, как показано ниже, он печатает как SKL как вывод
print(indexer.getInputCol())
Это мой полный код
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer,IndexToString
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml import Pipeline
spark=SparkSession.builder.appName('tree').getOrCreate()
userKnowDf=spark.read.csv('/FileStore/tables/UserKnowModelingDataset_Train.csv',inferSchema=True, header=True)
userKnowDf.describe().show()
indexer = StringIndexer(inputCol="SKL", outputCol="IndexedLabel")
indexedUserKnow = indexer.fit(userKnowDf).transform(userKnowDf)
df= indexedUserKnow.selectExpr(("IndexedLabel"),("SKL as label"),("SST"),("SRT"),("SAT"),("SAP"),("SEP"));
assembler=VectorAssembler(inputCols=['SST','SRT','SAT','SAP','SEP'],outputCol='features')
LRdf = assembler.transform(df).select('IndexedLabel','label','features')
split=LRdf.randomSplit([0.7,0.3])
trainingdata=split[0]
testdata=split[1]
dtc = DecisionTreeClassifier(labelCol='IndexedLabel', featuresCol='features')
dtc_model=dtc.fit(trainingdata)
dtc_preds = dtc_model.transform(testdata)
dtc_preds.show()
PredConverter = IndexToString(inputCol='prediction', outputCol='predictionStr',labels=indexer.SKL) // Getting error here
predictions = PredConverter.transform(dtc_preds)
predictions.show()