java .lang.StackOverflowError при выполнении подгонки Spark TrainValidationSplit - PullRequest
1 голос
/ 05 февраля 2020

Я пытался найти эту проблему во многих других сообщениях, но не смог найти решение, открывая его.

Я пытаюсь использовать Spark TrainValidationSplit для настройка гипер-параметров моей модели.

Я запускаю свой код в автономном локальном режиме на 16-гигабайтном 8-ядерном P C со следующей конфигурацией в spark-defaults.conf: spark.driver.memory 14g. Это единственная конфигурация, которую я делаю.

Я запускаю примерку с помощью классификатора Gradient Boosted Tree следующим образом:

tvs = TrainValidationSplit(estimator=classifier,
                           estimatorParamMaps=param_grid,
                           evaluator=evaluator,
                           trainRatio=0.8,
                           parallelism=4)
model = tvs.fit(training)

И я получаю очень длинную ошибку трассировки стека (25000 строк!) без полезной информации, которая выглядит так и, кажется, повторяется:

16:55:57 ERROR SparkUncaughtExceptionHandler: Uncaught exception in thread Thread[Executor task launch worker for task 101611,5,main]
java.lang.StackOverflowError
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2331)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2266)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:465)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:423)
    at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:490)
    at sun.reflect.GeneratedMethodAccessor14.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1170)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2233)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2342)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2266)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2342)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2266)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:465)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:423)
    at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:490)
    at sun.reflect.GeneratedMethodAccessor14.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1170)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2233)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2342)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2266)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2342)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2266)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:465)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:423)
    at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:490)
    at sun.reflect.GeneratedMethodAccessor14.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    ...

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/.../venv/lib/python3.6/site-packages/py4j/java_gateway.py", line 1067, in start
    self.socket.connect((self.address, self.port))
ConnectionRefusedError: [Errno 111] Connection refused
...

Когда я тренирую другую модель, такую ​​как Случайный Лес, это работает. Поэтому я предполагаю, что это потому, что я получил больше комбинаций гиперпараметров для тестирования в TrainValidationSplit для GBT (36), чем RF (4), так как списки, использованные для построения показа paramgrid:

RF_MAX_DEPTH = [5, 7]
RF_MAX_BINS = [32]
RF_NUM_TREES = [10, 5]
RF_IMPURITY = ['entropy']
RF_FEATURE_SUBSET_STRATEGY = ['auto']
RF_MIN_INSTANCES_PER_NODE = [1]
RF_MIN_INFO_GAIN = [0.0]
RF_SUBSAMPLING_RATE = [0.8]
RF_MAX_MEMORY_IN_MB = [8192]
RF_CACHE_NODE_IDS = [False]

GBT_MAX_DEPTH = [5, 7, 15, 20]
GBT_MAX_BINS = [32]
GBT_MAX_ITER = [200, 300, 1000]
GBT_STEP_SIZE = [0.05, 0.03, 0.2]
GBT_LOSS_TYPE = ['logistic']
GBT_MIN_INSTANCES_PER_NODE = [1]
GBT_MIN_INFO_GAIN = [0.0]
GBT_SUBSAMPLING_RATE = [0.8]
GBT_MAX_MEMORY_IN_MB = [8192]
GBT_CACHE_NODE_IDS = [True]

Почему TrainValidationSplit выдает эту ошибку и как я могу ее решить (кроме удаления некоторых возможностей гиперпараметров)?

Заранее благодарен за любую помощь:)

EDIT : я удалил некоторые гиперпараметры возможности (так что есть только одна возможная комбинация), и это работает. Так что проблема есть, но я действительно хочу, чтобы она работала.

EDIT 2 : я попытался добавить spark.driver.extraJavaOptions -XX:ThreadStackSize=81920, так как это, возможно, связано с настройками jvm (так как Java ошибка StackOverflow), но это только сделало его sh позже.

...