Я пытался найти эту проблему во многих других сообщениях, но не смог найти решение, открывая его.
Я пытаюсь использовать Spark TrainValidationSplit для настройка гипер-параметров моей модели.
Я запускаю свой код в автономном локальном режиме на 16-гигабайтном 8-ядерном P C со следующей конфигурацией в spark-defaults.conf
: spark.driver.memory 14g
. Это единственная конфигурация, которую я делаю.
Я запускаю примерку с помощью классификатора Gradient Boosted Tree следующим образом:
tvs = TrainValidationSplit(estimator=classifier,
estimatorParamMaps=param_grid,
evaluator=evaluator,
trainRatio=0.8,
parallelism=4)
model = tvs.fit(training)
И я получаю очень длинную ошибку трассировки стека (25000 строк!) без полезной информации, которая выглядит так и, кажется, повторяется:
16:55:57 ERROR SparkUncaughtExceptionHandler: Uncaught exception in thread Thread[Executor task launch worker for task 101611,5,main]
java.lang.StackOverflowError
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2331)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2266)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:465)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:423)
at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:490)
at sun.reflect.GeneratedMethodAccessor14.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1170)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2233)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2342)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2266)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2342)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2266)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:465)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:423)
at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:490)
at sun.reflect.GeneratedMethodAccessor14.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1170)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2233)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2342)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2266)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2342)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2266)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:465)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:423)
at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:490)
at sun.reflect.GeneratedMethodAccessor14.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
...
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/.../venv/lib/python3.6/site-packages/py4j/java_gateway.py", line 1067, in start
self.socket.connect((self.address, self.port))
ConnectionRefusedError: [Errno 111] Connection refused
...
Когда я тренирую другую модель, такую как Случайный Лес, это работает. Поэтому я предполагаю, что это потому, что я получил больше комбинаций гиперпараметров для тестирования в TrainValidationSplit для GBT (36), чем RF (4), так как списки, использованные для построения показа paramgrid:
RF_MAX_DEPTH = [5, 7]
RF_MAX_BINS = [32]
RF_NUM_TREES = [10, 5]
RF_IMPURITY = ['entropy']
RF_FEATURE_SUBSET_STRATEGY = ['auto']
RF_MIN_INSTANCES_PER_NODE = [1]
RF_MIN_INFO_GAIN = [0.0]
RF_SUBSAMPLING_RATE = [0.8]
RF_MAX_MEMORY_IN_MB = [8192]
RF_CACHE_NODE_IDS = [False]
GBT_MAX_DEPTH = [5, 7, 15, 20]
GBT_MAX_BINS = [32]
GBT_MAX_ITER = [200, 300, 1000]
GBT_STEP_SIZE = [0.05, 0.03, 0.2]
GBT_LOSS_TYPE = ['logistic']
GBT_MIN_INSTANCES_PER_NODE = [1]
GBT_MIN_INFO_GAIN = [0.0]
GBT_SUBSAMPLING_RATE = [0.8]
GBT_MAX_MEMORY_IN_MB = [8192]
GBT_CACHE_NODE_IDS = [True]
Почему TrainValidationSplit выдает эту ошибку и как я могу ее решить (кроме удаления некоторых возможностей гиперпараметров)?
Заранее благодарен за любую помощь:)
EDIT : я удалил некоторые гиперпараметры возможности (так что есть только одна возможная комбинация), и это работает. Так что проблема есть, но я действительно хочу, чтобы она работала.
EDIT 2 : я попытался добавить spark.driver.extraJavaOptions -XX:ThreadStackSize=81920
, так как это, возможно, связано с настройками jvm (так как Java ошибка StackOverflow), но это только сделало его sh позже.