Цель состоит в том, чтобы получить точность для каждой модели, чтобы показать, что точность улучшается каждый раз, когда мы улучшаем число наблюдений.Я должен использовать SVMWithSGD для тренировочных данных.
Проблема в том, что я всегда получаю одинаковую точность.
ОС: Linux 4.15.0-46-generic (gcc version 7.3.0 (Ubuntu 7.3.0-16ubuntu3))
Java: java version 1.8.2_201
Python: python 3.6.7 :: Anaconda, Inc..
Анаконда: Conda 4.5.11
Искра: spark-2.4.0-bin-hadoop-2.7
#!usr/bin/env python3
#-*- coding:utf-8 -*-
from pyspark import SparkContext
from pyspark.mllib.classification import SVMWithSGD, SVMModel
from pyspark.mllib.regression import LabeledPoint
def parseTestLP(line):
return LabeledPoint(line[1][0], line[1][1])
def parseTrainLP(line):
return (line[0], (LabeledPoint(line[0], line[1][1])))
def main():
sc = SparkContext(master='local[6]')
rdd = sc.textFile('features', minPartitions=6)\
.map(lambda line: line.split(','))\
.map(lambda x: (x[0]+x[1], (1.0 if x[0].startswith('yorkshire_terrier')\
else 0.0, list(map(float, x[2:])))))
rdd.persist()
rddTrainBase = rdd.map(lambda x: (x[1][0], (x[0], x[1][1])))\
.sampleByKey(False, {0.0: 0.8, 1.0: 0.8})
rddTrainBase.persist()
rddTrainToSubtract = rddTrainBase.map(lambda x: (x[1][0], (x[0], x[1][1])))
rddTest = rdd.subtractByKey(rddTrainToSubtract)
rddTestLP = rddTest.map(parseTestLP)
rddTestLP.persist()
dataTestCount = rddTestLP.count()
testNumber = dataTestCount
rddTrainToSamp = rddTrainBase.map(parseTrainLP)
iterList = [0.05, 0.3]
dico = {}
for part in iterList:
rddTrainSampled = rddTrainToSamp.sampleByKey(False, {0.0: part, 1.0: part})
rddTrainToModel = rddTrainSampled.map(lambda x: x[1])
dataTrainCount = rddTrainToModel.count()
obsNumber = dataTrainCount
model = SVMWithSGD.train(rddTrainToModel, iterations=5)
labsP = rddTestLP.map(lambda p: (p.label, model.predict(p.features)))
accuracy = labsP.filter(lambda lp: lp[0] == lp[1]).count() / float(testNumber)
dico[obsNumber]=accuracy
print(dico)
if __name__ == "__main__":
main()
Я ожидаю, что выходной (например):
{307: 0.9700000000000001, 1737: 0.9899999999999999}
Но я всегда получаю:
{307: 0.9785016286644951, 1737: 0.9785016286644951}
Еще один пример, который я получил:
{280: 0.9761737911702874, 1782: 0.9761737911702874}
Последний:
{294: 0.9675894665766374, 1751: 0.9675894665766374}
Заранее благодарен за помощь.