K-NN in PySpark

asked 14 July 2020

The following code classifies the poker dataset, which has 10 features (all numeric) and a class label taking 10 values (also numeric). I used sklearn's K-NN in PySpark with a custom distance function. It throws an error when broadcasting the K-NN model and predicting the test labels. When I don't use the custom function, it shows no errors. Why is this happening?

import numpy as np
from sklearn.neighbors import BallTree, KNeighborsClassifier

x = sc.textFile("/home/ritesh/Spark/poker100.txt")

def parseLine(line):
    cols = line.split(',')   # the file is comma-separated
    label = cols[-1]         # the label is the last column
    vector = cols[:-1]       # the features are every column except the label
    # convert each value from string to float
    vector = np.array(vector, dtype=np.float64).tolist()
    return (label, vector)

x = x.map(parseLine)
train, test = x.randomSplit([0.7, 0.3], seed=100)

X = train.map(lambda x: x[1]).collect()  # collect training features
Y = train.map(lambda x: x[0]).collect()  # collect training labels
y = test.map(lambda x: x[0]).collect()   # collect testing labels

def dist(x, y):
    # Euclidean distance between a training point and a testing point
    return np.sqrt(np.sum((x - y) ** 2))

BallTree.valid_metrics  # inspect the metric names BallTree accepts

knn = KNeighborsClassifier(n_neighbors=3, algorithm='ball_tree', metric=dist)
model = knn.fit(X, Y)        # fit the KNN model on the driver
model = sc.broadcast(model)  # broadcast the fitted model to the executors
testdata = test.map(lambda x: model.value.predict(np.array(x[1], dtype="float64").reshape(1, -1)))  # predict the test data
y_pred = testdata.collect()
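
For comparison, here is a minimal sketch of the variant without the custom distance function, which, as noted above, runs without errors (it assumes the same X, Y, and test as in the code above; 'euclidean' is one of BallTree's built-in metrics):

# Same pipeline, but with the built-in Euclidean metric instead of the
# interactively defined dist(); this variant broadcasts and predicts cleanly.
knn = KNeighborsClassifier(n_neighbors=3, algorithm='ball_tree', metric='euclidean')
model = sc.broadcast(knn.fit(X, Y))
y_pred = test.map(lambda x: model.value.predict(np.array(x[1], dtype="float64").reshape(1, -1))).collect()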

Running the version with the custom metric raises:

Py4JJavaError                             Traceback (most recent call last)
<ipython-input-113-a20ddffd3048> in <module>()
      1 model=sc.broadcast(model)
      2 testdata=test.map(lambda x: model.value.predict(np.array(x[1],dtype="float64").reshape(1,-1)))
----> 3 y_pred=testdata.collect()

/apps/spark-2.4.3/python/pyspark/rdd.py in collect(self)
    814         """
    815         with SCCallSiteSync(self.context) as css:
--> 816             sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
    817         return list(_load_from_socket(sock_info, self._jrdd_deserializer))
    818 

/apps/spark-2.4.3/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

/apps/spark-2.4.3/python/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/apps/spark-2.4.3/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 43.0 failed 1 times, most recent failure: Lost task 1.0 in stage 43.0 (TID 87, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/worker.py", line 377, in main
    process()
  File "/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/worker.py", line 372, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/serializers.py", line 393, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<ipython-input-113-a20ddffd3048>", line 2, in <lambda>
  File "/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/broadcast.py", line 148, in value
    self._value = self.load_from_path(self._path)
  File "/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/broadcast.py", line 125, in load_from_path
    return self.load(f)
  File "/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/broadcast.py", line 131, in load
    return pickle.load(file)
AttributeError: Can't get attribute 'dist' on <module 'pyspark.daemon' from '/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/daemon.py'>
...
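
The final AttributeError points at the likely cause: sc.broadcast serializes its value with plain pickle, which stores the custom dist function by reference to the notebook's __main__ module, and on a worker __main__ is pyspark.daemon, where no dist attribute exists. Functions captured inside a map closure are serialized with cloudpickle instead, which ships them by value, which would explain why the built-in-metric version works. Under that assumption, a minimal workaround sketch is to skip the broadcast and let the closure carry the fitted model:

# Workaround sketch (assumes the cloudpickle-vs-pickle explanation above):
# reference the fitted model directly in the closure instead of broadcasting
# it, so the interactively defined dist() is shipped by value.
model = knn.fit(X, Y)
y_pred = test.map(lambda x: model.predict(np.array(x[1], dtype="float64").reshape(1, -1))).collect()

Alternatively, defining dist in a standalone .py module and shipping it to the executors with sc.addPyFile should let plain pickle resolve it by name on the workers.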