Ярлык софт макс Amazon SageMaker должен быть в [0, num_class) - PullRequest
0 голосов
/ 04 апреля 2020

Я начинаю с проблемы мультиклассовой классификации с использованием Xgboost и экспериментирую с SageMaker HyperparameterTuner. Похоже, все настроено правильно, но я продолжаю получать статусы с ошибками Когда я проверяю журналы CloudWatch, я вижу, что это проблема. "SoftmaxMultiClassObj: метка должна быть в [0, num_class)." Я знаю, сколько классов у меня в данных 5 (0,1,2,3,4), и они находятся в первом столбце наборов данных обучения и проверки.

Ниже приведены мои конфигурации:

xgb = sagemaker.estimator.Estimator(container,
                                        role, 
                                        train_instance_count=1, 
                                        train_instance_type='ml.m4.xlarge',
                                        output_path='s3://{}/{}/output'.format(bucket, prefix),
                                        sagemaker_session=sess)

xgb.set_hyperparameters(eval_metric='merror', #'rmse',
                            objective = 'multi:softmax',#'binary:logistic',
                            num_round=100,
                            rate_drop=0.3,
                            num_class= 5)

hyperparameter_ranges = {'eta': ContinuousParameter(0, 1),
                        'min_child_weight': ContinuousParameter(1, 10),
                        'alpha': ContinuousParameter(0, 2),
                        'max_depth': IntegerParameter(1, 10)}

objective_metric_name = 'validation:accuracy'

tuner = HyperparameterTuner(xgb,
                            objective_metric_name,
                            hyperparameter_ranges,
                            max_jobs=20,
                            max_parallel_jobs=3)

s3_input_train = sagemaker.s3_input(s3_data='s3://{}/{}/TrainingData'.format(bucket, prefix), content_type='csv')
s3_input_validation = sagemaker.s3_input(s3_data='s3://{}/{}/ValidationData'.format(bucket, prefix), content_type='csv')

tuner.fit({'train': s3_input_train, 'validation': s3_input_validation}, include_cls_metadata=False)

И для хороших мер здесь приведен набор для первого столбца набора данных, показывающий, что действительно существует 5 классов.

data=pd.read_csv('TrainingData.csv')
set(data['0'])

{0, 1, 2, 3, 4}

Я перезагружал учебные задания SageMaker несколько раз и сам ноутбук, но продолжаю получать ту же ошибку. Что ты думаешь попробовать?

...