The following code classifies the poker hand dataset, which has 10 numeric features and a class label with 10 classes (also numeric). I used sklearn's K-NN in PySpark with a custom distance function. It throws an error when broadcasting the fitted K-NN model and predicting the test labels. When I don't use the custom function, there are no errors. Why is this happening?
import numpy as np

x = sc.textFile("/home/ritesh/Spark/poker100.txt")

def parseLine(line):
    # split the CSV line on ','
    cols = line.split(',')
    # the label is the last column
    label = cols[-1]
    # the feature vector is every column except the label,
    # converted from string to float
    vector = np.array(cols[:-1], dtype=np.float64).tolist()
    return (label, vector)

x = x.map(parseLine)
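For reference, a single row would parse like this (the sample row below is made up; the real file's contents may differ):

parseLine("1,10,1,11,1,13,1,12,1,1,9")
# -> ('9', [1.0, 10.0, 1.0, 11.0, 1.0, 13.0, 1.0, 12.0, 1.0, 1.0])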
train, test = x.randomSplit([0.7, 0.3], seed=100)

# collect the training features and labels to the driver for sklearn
X = train.map(lambda x: x[1]).collect()
Y = train.map(lambda x: x[0]).collect()
# collect the testing labels
y = test.map(lambda x: x[0]).collect()
# Euclidean distance between a training and a testing point
def dist(x, y):
    return np.sqrt(np.sum((x - y) ** 2))
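As a quick sanity check on the metric (the sample points are made up), the 3-4-5 triangle gives exactly 5.0, matching sklearn's built-in 'euclidean' metric:

dist(np.array([0.0, 3.0]), np.array([4.0, 0.0]))  # -> 5.0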
from sklearn.neighbors import BallTree
from sklearn.neighbors import KNeighborsClassifier

BallTree.valid_metrics  # metrics BallTree supports natively; a callable is also allowed

knn = KNeighborsClassifier(n_neighbors=3, algorithm='ball_tree', metric=dist)
model = knn.fit(X, Y)  # fit the KNN model on the driver
model = sc.broadcast(model)
# predict each test point on the executors through the broadcast model
testdata = test.map(lambda x: model.value.predict(np.array(x[1], dtype="float64").reshape(1, -1)))
y_pred = testdata.collect()
When I run it, it throws this error:
Py4JJavaError
Traceback (most recent call last)
<ipython-input-113-a20ddffd3048> in <module>()
1 model=sc.broadcast(model)
2 testdata=test.map(lambda x: model.value.predict(np.array(x[1],dtype="float64").reshape(1,-1)))
----> 3 y_pred=testdata.collect()
/apps/spark-2.4.3/python/pyspark/rdd.py in collect(self)
814 """
815 with SCCallSiteSync(self.context) as css:
--> 816 sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
817 return list(_load_from_socket(sock_info, self._jrdd_deserializer))
818
/apps/spark-2.4.3/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
1255 answer = self.gateway_client.send_command(command)
1256 return_value = get_return_value(
-> 1257 answer, self.gateway_client, self.target_id, self.name)
1258
1259 for temp_arg in temp_args:
/apps/spark-2.4.3/python/pyspark/sql/utils.py in deco(*a, **kw)
61 def deco(*a, **kw):
62 try:
---> 63 return f(*a, **kw)
64 except py4j.protocol.Py4JJavaError as e:
65 s = e.java_exception.toString()
/apps/spark-2.4.3/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.\n".
--> 328 format(target_id, ".", name), value)
329 else:
330 raise Py4JError(
Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 43.0 failed 1 times, most recent failure: Lost task 1.0 in stage 43.0 (TID 87, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/worker.py", line 377, in main
process()
File "/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/worker.py", line 372, in process
serializer.dump_stream(func(split_index, iterator), outfile)
File "/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/serializers.py", line 393, in dump_stream
vs = list(itertools.islice(iterator, batch))
File "/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
return f(*args, **kwargs)
File "<ipython-input-113-a20ddffd3048>", line 2, in <lambda>
File "/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/broadcast.py", line 148, in value
self._value = self.load_from_path(self._path)
File "/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/broadcast.py", line 125, in load_from_path
return self.load(f)
File "/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/broadcast.py", line 131, in load
return pickle.load(file)
AttributeError: Can't get attribute 'dist' on <module 'pyspark.daemon' from '/apps/spark-2.4.3/python/lib/pyspark.zip/pyspark/daemon.py'>
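From what I can tell, the final AttributeError is a pickling problem rather than a Spark one: sc.broadcast serializes the fitted model with plain pickle, which stores the custom metric only as a named reference to dist in the driver's __main__; when an executor unpickles the broadcast value, that name does not exist in its interpreter (which runs as pyspark.daemon), so the lookup fails. A built-in metric is just a string, which would explain why the error disappears without the custom function. A minimal sketch of one possible workaround, assuming the metric is moved into a separate file named distmetric.py (a hypothetical name) and shipped to the executors with sc.addPyFile:

# distmetric.py -- hypothetical module; keeping the metric out of the
# notebook's __main__ lets pickle resolve it by module path on executors
import numpy as np

def dist(x, y):
    return np.sqrt(np.sum((x - y) ** 2))

# driver side:
import distmetric
sc.addPyFile("distmetric.py")  # ship the module to every executor

knn = KNeighborsClassifier(n_neighbors=3, algorithm='ball_tree',
                           metric=distmetric.dist)
model = sc.broadcast(knn.fit(X, Y))  # the broadcast should now unpickle on executors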