I'm using this code with a book-ratings database. Calling `ALS.predict()` in the `make_recommendation` function works fine, but calling `ALS.predictAll()` raises an error. The `make_recommendation()` code is as follows:
```python
from pyspark.mllib.recommendation import ALS

# train the best ALS model
model = ALS.train(
    ratings=train_data,
    iterations=best_model_params.get('iterations', None),
    rank=best_model_params.get('rank', None),
    lambda_=best_model_params.get('lambda_', None),
    seed=99)
# build the inference RDD of (user, book) pairs
inference_rdd = get_inference_data(ratings_data, df_book, bookId_list)
# inference
predictions = model.predictAll(inference_rdd)  # .map(lambda r: (r[1], r[2]))
```
And this is `get_inference_data()`:
```python
def get_inference_data(train_data, df_book, bookId_list):
    # get a new user id (one above the largest existing user id)
    new_id = train_data.map(lambda r: r[0]).max() + 1
    # return an RDD of (new_id, book_id) pairs for books not in bookId_list
    return df_book.rdd \
        .map(lambda r: r[0]) \
        .distinct() \
        .filter(lambda x: x not in bookId_list) \
        .map(lambda x: (new_id, int(x)))
```
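To check whether the inference pairs are the problem without running Spark, the logic of `get_inference_data()` can be replayed locally on a small sample. The sketch below is an illustration under assumptions, not the author's code: `inference_pairs` and `out_of_int_range` are hypothetical helpers, the sample IDs are made up, and the 32-bit limit assumes that MLlib's ALS stores user and product IDs as a Java `Int`, which would make any larger ID a candidate cause of the `ClassCastException`.

```python
# Hypothetical local replay of get_inference_data() on plain Python lists.
# Assumption: MLlib ALS needs user and product IDs that fit in a signed
# 32-bit Java Int, so anything outside that range is flagged.

INT_MIN, INT_MAX = -(2**31), 2**31 - 1


def inference_pairs(rated_user_ids, all_book_ids, rated_book_ids):
    """Mirror get_inference_data(): a new user id paired with every unrated book."""
    new_id = max(rated_user_ids) + 1  # same rule as train_data.map(...).max() + 1
    rated = set(rated_book_ids)
    # distinct book ids, minus the ones already rated, converted with int()
    return [(new_id, int(b)) for b in sorted(set(all_book_ids)) if b not in rated]


def out_of_int_range(pairs):
    """Return the (user, product) pairs that a Java Int cannot hold."""
    return [
        (u, p) for u, p in pairs
        if not (INT_MIN <= u <= INT_MAX and INT_MIN <= p <= INT_MAX)
    ]


if __name__ == "__main__":
    # Made-up sample: one book id is larger than 2**31 - 1.
    pairs = inference_pairs([1, 2, 3], [10, 20, 9780306406157], [10])
    print(pairs)
    print(out_of_int_range(pairs))
```

On the real data, a similar check could be run directly on the RDD, e.g. `inference_rdd.filter(lambda p: not all(-2**31 <= v < 2**31 for v in p)).take(5)`; an empty result would point away from this explanation.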
And this is the error:
```
19/11/03 19:47:47 ERROR TaskSetManager: Task 0 in stage 170.0 failed 1 times; aborting job
Traceback (most recent call last):
  File "/Users/apple/Documents/project/codes/shed/fff.py", line 329, in <module>
    spark_context=sc)
  File "/Users/apple/Documents/project/codes/shed/fff.py", line 303, in make_recommendation
    predictions = model.predictAll(inference_rdd) # .map(lambda r: (r[1], r[2]))
  File "/usr/local/lib/python3.7/site-packages/pyspark/mllib/recommendation.py", line 149, in predictAll
    return self.call("predict", user_product)
  File "/usr/local/lib/python3.7/site-packages/pyspark/mllib/common.py", line 146, in call
    return callJavaFunc(self._sc, getattr(self._java_model, name), *a)
  File "/usr/local/lib/python3.7/site-packages/pyspark/mllib/common.py", line 123, in callJavaFunc
    return _java2py(sc, func(*args))
  File "/usr/local/lib/python3.7/site-packages/py4j/java_gateway.py", line 1257, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/usr/local/lib/python3.7/site-packages/pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/usr/local/lib/python3.7/site-packages/py4j/protocol.py", line 328, in get_return_value
    format(target_id, ".", name), value)
py4j.protocol.Py4JJavaError: An error occurred while calling o146.predict.
```
And in another part of the error output:
```
19/11/03 19:47:47 WARN TaskSetManager: Lost task 0.0 in stage 170.0 (TID 140, localhost, executor driver): java.lang.ClassCastException: java.lang.Long cannot be cast
```
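The `ClassCastException` means a `java.lang.Long` arrived where the JVM side expected another type. One plausible mechanism, stated here as an assumption rather than a confirmed diagnosis: PySpark pickles RDD elements before passing them to the JVM, and Python's pickle stores integers in the signed 32-bit range with the `BININT*` opcodes but larger ones with `LONG1`, which a Java unpickler would likely turn into a `Long`. The standard-library snippet below shows which opcode each value gets; `int_opcode` is a hypothetical helper.

```python
import pickle
import pickletools


def int_opcode(value, protocol=2):
    """Return the name of the pickle opcode that stores a bare int."""
    data = pickle.dumps(value, protocol=protocol)
    names = [op.name for op, _arg, _pos in pickletools.genops(data)]
    # Skip framing/terminator opcodes and keep the one that pushes the int.
    return next(n for n in names if n not in ("PROTO", "FRAME", "STOP"))


if __name__ == "__main__":
    for v in (7, 2**31 - 1, 2**31):
        print(v, int_opcode(v))
    # 7 BININT1
    # 2147483647 BININT
    # 2147483648 LONG1
```

If the `LONG1` case matches the actual IDs, keeping them inside the 32-bit range (for example by remapping them to a dense `0..n-1` index) is one common workaround; whether it applies here depends on the real ID values.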
I'm using Java 1.8 and Spark 2.4.4 on macOS, and I'm just learning PySpark.