Параметр one_hot_explicit для h2o python вызывает ошибку - PullRequest
0 голосов
/ 09 мая 2018

При обучении модели в h2o v3.10 с использованием библиотеки python h2o возникает ошибка при попытке установить one_hot_explicit в качестве параметра categoryorical_encoding .

encoding = "enum"

gbm = H2OGradientBoostingEstimator(
        categorical_encoding = encoding)

gbm.train(x, y,train_h2o_df,test_h2o_df)

Работает нормально, и модель использует enum categoryorical_encoding, но когда:

encoding = "one_hot_explicit"

или

encoding = "OneHotExplicit"

возникает следующая ошибка:

gbm Model Build progress: | (failed)
....
OSError: Job with key $03017f00000132d4ffffffff$_bde8fcb4777df7e0be1199bf590a47f9 failed with an exception: java.lang.AssertionError
stacktrace: 
java.lang.AssertionError
at hex.ModelBuilder.init(ModelBuilder.java:958)
at hex.tree.SharedTree.init(SharedTree.java:78)
at hex.tree.gbm.GBM.init(GBM.java:57)
at hex.tree.SharedTree$Driver.computeImpl(SharedTree.java:159)
at hex.ModelBuilder$Driver.compute2(ModelBuilder.java:169)
at water.H2O$H2OCountedCompleter.compute(H2O.java:1203)
at jsr166y.CountedCompleter.exec(CountedCompleter.java:468)
at jsr166y.ForkJoinTask.doExec(ForkJoinTask.java:263)
at jsr166y.ForkJoinPool$WorkQueue.runTask(ForkJoinPool.java:974)
at jsr166y.ForkJoinPool.runWorker(ForkJoinPool.java:1477)
at jsr166y.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:104)

Есть какая-то зависимость, по которой я скучаю или это ошибка?

1 Ответ

0 голосов
/ 09 мая 2018

ваш выбор кодировки должен работать, хотя вы можете обновить его до последней стабильной версии H2O. Вот фрагмент кода, который вы можете запустить, который работает, и проверьте, работает ли он для вас. Если это работает, тогда вы можете попытаться определить разницу между вашим предыдущим кодом и приведенным ниже примером.

import h2o
from h2o.estimators.gbm import H2OGradientBoostingEstimator
h2o.init()

# import the airlines dataset:
# This dataset is used to classify whether a flight will be delayed 'YES' or not "NO"
# original data can be found at http://www.transtats.bts.gov/
airlines= h2o.import_file("https://s3.amazonaws.com/h2o-public-test-data/smalldata/airlines/allyears2k_headers.zip")

# convert columns to factors
airlines["Year"]= airlines["Year"].asfactor()
airlines["Month"]= airlines["Month"].asfactor()
airlines["DayOfWeek"] = airlines["DayOfWeek"].asfactor()

# set the predictor names and the response column name
predictors = ["Origin", "Dest", "Year", "DayOfWeek", "Month", "Distance"]
response = "IsDepDelayed"

# split into train and validation sets
train, valid= airlines.split_frame(ratios = [.8], seed = 1234)

# try using the `categorical_encoding` parameter:
encoding = "one_hot_explicit"

# initialize the estimator
airlines_gbm = H2OGradientBoostingEstimator(categorical_encoding = encoding, seed =1234)

# then train the model
airlines_gbm.train(x = predictors, y = response, training_frame = train, validation_frame = valid)

# print the auc for the validation set
airlines_gbm.auc(valid=True)
...