Я использую пакет TPOT вместе с Dask и при работе с удалённым кластером Dask получаю исключение.
Контекст вопроса
Я развернул кластер Dask в Google Cloud Container Engine согласно документации: http://docs.dask.org/en/latest/setup/kubernetes-helm.html
В файле conf.yaml я указал следующие зависимости для рабочих узлов (workers):
- name: EXTRA_PIP_PACKAGES
  value: s3fs tpot scikit-learn featuretools dask-ml[complete] dask[complete] deap xgboost --upgrade
Шаги для воспроизведения
client = Client(address='CLUSTER-IPADDRESS:8786')
tpot = TPOTRegressor(generations=20, population_size=30, verbosity=2, use_dask=True)
tpot.fit(X_train, y_train)
Ожидаемый результат
Вызов tpot.fit() завершается без ошибок и исключений.
Фактический результат
После вызова tpot.fit()
на стороне клиента выводится следующая трассировка стека:
Imputing missing values in feature set
/usr/local/lib/python3.7/site-packages/sklearn/utils/deprecation.py:58: DeprecationWarning: Class Imputer is deprecated; Imputer was deprecated in version 0.20 and will be removed in 0.22. Import impute.SimpleImputer from sklearn instead.
warnings.warn(msg, category=DeprecationWarning)
Optimization Progress
0% 0/630 [00:00<?, ?pipeline/s]
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
/usr/local/lib/python3.7/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
1492 try:
-> 1493 st = self.futures[key]
1494 exception = st.exception
KeyError: 'nanmean-4dd454cc-ecd5-48a5-8222-1603a54abf65'
During handling of the above exception, another exception occurred:
CancelledError Traceback (most recent call last)
/usr/local/lib/python3.7/site-packages/tpot/base.py in fit(self, features, target, sample_weight, groups)
660 verbose=self.verbosity,
--> 661 per_generation_function=self._check_periodic_pipeline
662 )
/usr/local/lib/python3.7/site-packages/tpot/gp_deap.py in eaMuPlusLambda(population, toolbox, mu, lambda_, cxpb, mutpb, ngen, pbar, stats, halloffame, verbose, per_generation_function)
229
--> 230 fitnesses = toolbox.evaluate(invalid_ind)
231 for ind, fit in zip(invalid_ind, fitnesses):
/usr/local/lib/python3.7/site-packages/tpot/base.py in _evaluate_individuals(self, individuals, features, target, sample_weight, groups)
1223 warnings.simplefilter('ignore')
-> 1224 result_score_list = list(dask.compute(*result_score_list))
1225
/usr/local/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
396 postcomputes = [x.__dask_postcompute__() for x in collections]
--> 397 results = schedule(dsk, keys, **kwargs)
398 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
/usr/local/lib/python3.7/site-packages/distributed/client.py in get(self, dsk, keys, restrictions, loose_restrictions, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
2337 results = self.gather(packed, asynchronous=asynchronous,
-> 2338 direct=direct)
2339 finally:
/usr/local/lib/python3.7/site-packages/distributed/client.py in gather(self, futures, errors, maxsize, direct, asynchronous)
1661 direct=direct, local_worker=local_worker,
-> 1662 asynchronous=asynchronous)
1663
/usr/local/lib/python3.7/site-packages/distributed/client.py in sync(self, func, *args, **kwargs)
675 else:
--> 676 return sync(self.loop, func, *args, **kwargs)
677
/usr/local/lib/python3.7/site-packages/distributed/utils.py in sync(loop, func, *args, **kwargs)
276 if error[0]:
--> 277 six.reraise(*error[0])
278 else:
/usr/local/lib/python3.7/site-packages/six.py in reraise(tp, value, tb)
692 raise value.with_traceback(tb)
--> 693 raise value
694 finally:
/usr/local/lib/python3.7/site-packages/distributed/utils.py in f()
261 future = gen.with_timeout(timedelta(seconds=timeout), future)
--> 262 result[0] = yield future
263 except Exception as exc:
/usr/local/lib/python3.7/site-packages/tornado/gen.py in run(self)
1132 try:
-> 1133 value = future.result()
1134 except Exception:
/usr/local/lib/python3.7/site-packages/tornado/gen.py in run(self)
1140 try:
-> 1141 yielded = self.gen.throw(*exc_info)
1142 finally:
/usr/local/lib/python3.7/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
1498 CancelledError(key),
-> 1499 None)
1500 else:
/usr/local/lib/python3.7/site-packages/six.py in reraise(tp, value, tb)
692 raise value.with_traceback(tb)
--> 693 raise value
694 finally:
CancelledError: nanmean-4dd454cc-ecd5-48a5-8222-1603a54abf65
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
<ipython-input-13-fc63521ba7ad> in <module>
----> 1 tpot.fit(X_train, y_train)
/usr/local/lib/python3.7/site-packages/tpot/base.py in fit(self, features, target, sample_weight, groups)
691 # raise the exception if it's our last attempt
692 if attempt == (attempts - 1):
--> 693 raise e
694 return self
695
/usr/local/lib/python3.7/site-packages/tpot/base.py in fit(self, features, target, sample_weight, groups)
682 self._pbar.close()
683
--> 684 self._update_top_pipeline()
685 self._summary_of_best_pipeline(features, target)
686 # Delete the temporary cache before exiting
/usr/local/lib/python3.7/site-packages/tpot/base.py in _update_top_pipeline(self)
756 # If user passes CTRL+C in initial generation, self._pareto_front (halloffame) shoule be not updated yet.
757 # need raise RuntimeError because no pipeline has been optimized
--> 758 raise RuntimeError('A pipeline has not yet been optimized. Please call fit() first.')
759
760 def _summary_of_best_pipeline(self, features, target):
RuntimeError: A pipeline has not yet been optimized. Please call fit() first.
В журналах рабочих узлов (workers) кластера при этом обнаруживается следующая ошибка:
tornado.application - ERROR - Exception in callback functools.partial(<function wrap.<locals>.null_wrapper at 0x7f04974ea378>, <Future finished exception=TypeError("No dispatch for <class 'xgboost.sklearn.XGBRegressor'>",)>)
Traceback (most recent call last):
File "/opt/conda/lib/python3.6/site-packages/tornado/ioloop.py", line 758, in _run_callback
ret = callback()
File "/opt/conda/lib/python3.6/site-packages/tornado/stack_context.py", line 300, in null_wrapper
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.6/site-packages/tornado/ioloop.py", line 779, in _discard_future_result
future.result()
File "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", line 1141, in run
yielded = self.gen.throw(*exc_info)
File "/opt/conda/lib/python3.6/site-packages/distributed/worker.py", line 661, in handle_scheduler
self.ensure_computing])
File "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", line 1133, in run
value = future.result()
File "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", line 1141, in run
yielded = self.gen.throw(*exc_info)
File "/opt/conda/lib/python3.6/site-packages/distributed/core.py", line 386, in handle_stream
msgs = yield comm.read()
File "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", line 1133, in run
value = future.result()
File "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", line 1141, in run
yielded = self.gen.throw(*exc_info)
File "/opt/conda/lib/python3.6/site-packages/distributed/comm/tcp.py", line 206, in read
deserializers=deserializers)
File "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", line 1133, in run
value = future.result()
File "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", line 326, in wrapper
yielded = next(result)
File "/opt/conda/lib/python3.6/site-packages/distributed/comm/utils.py", line 79, in from_frames
res = _from_frames()
File "/opt/conda/lib/python3.6/site-packages/distributed/comm/utils.py", line 65, in _from_frames
deserializers=deserializers)
File "/opt/conda/lib/python3.6/site-packages/distributed/protocol/core.py", line 131, in loads
value = _deserialize(head, fs, deserializers=deserializers)
File "/opt/conda/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 178, in deserialize
return loads(header, frames)
File "/opt/conda/lib/python3.6/site-packages/distributed/protocol/serialize.py", line 48, in dask_loads
loads = dask_deserialize.dispatch(typ)
File "/opt/conda/lib/python3.6/site-packages/dask/utils.py", line 406, in dispatch
raise TypeError("No dispatch for {0}".format(cls))
TypeError: No dispatch for <class 'xgboost.sklearn.XGBRegressor'>
Сталкивался ли кто-нибудь с этой проблемой в Dask и удалось ли её исправить?