Повторная выборка MLR создает проблемы одного класса для классификации нескольких меток - PullRequest
1 голос
/ 28 мая 2019

Я пытаюсь измерить производительность многослойной классификации для некоторых классификаторов MLR с помощью перекрестной проверки

Я пытался использовать метод MLR resample или передать свое собственное подмножество, однако в обеих ситуациях выдается ошибка (из того, что я обнаружил, это происходит, когда подмножество, используемое для обучения, содержит только отдельные значения для некоторой метки)

Ниже приведен небольшой пример, где возникает эта проблема:

learner = mlr::makeLearner("classif.logreg")

learner = makeMultilabelClassifierChainsWrapper(learner)

data = data.frame(
    attr1 = c(1, 2, 2, 1, 2, 1, 2),
    attr2 = c(2, 1, 2, 2, 1, 2, 1),
    lab1 = c(FALSE, FALSE, TRUE, FALSE, FALSE, FALSE, FALSE),
    lab2 = c(FALSE, TRUE, FALSE, FALSE, FALSE, FALSE, FALSE))

task = mlr::makeMultilabelTask(data=data, target=c('lab1', 'lab2'))

есть два способа получить ошибку:

1

rDesc = makeResampleDesc("CV", iters = 3)

resample(learner, task, rDesc)

2

model = mlr::train(learner, task, subset=c(TRUE, FALSE, FALSE, TRUE, TRUE, TRUE, TRUE))

Сообщение об ошибке:

Ошибка в checkLearnerBeforeTrain (задача, ученик, вес): задача «lab1» - это проблема одного класса, но учащийся «classif.logreg» не поддерживает это!

1 Ответ

1 голос
/ 28 мая 2019

Поскольку в MLR нет учащихся, поддерживающих одноклассную (https://mlr.mlr -org.com / Articles / tutorial / integrated_learners.html ) классификацию и разделение данных, может потребоваться слишком много суеты (особеннодля наборов данных, таких как reutersk500), я создал оболочку для учеников, работающих с двумя классами, которая при задании с одним целевым классом всегда будет возвращать значение только для этого класса, а для других классов будет использоваться упакованный ученик:

(Этот кодбудет частью репозитория https://github.com/lychanl/ChainsOfClassification)

makeOneClassWrapper = function(learner) {
    learner = checkLearner(learner, type='classif')
    id = paste("classif.oneClassWrapper", getLearnerId(learner), sep = ".")
    packs = getLearnerPackages(learner)
    type = getLearnerType(learner)
    x = mlr::makeBaseWrapper(id, type, learner, packs, makeParamSet(),
        learner.subclass = c("OneClassWrapper"),
        model.subclass = c("OneClassWrapperModel"))
    x$type = "classif"
    x$properties = c(learner$properties, 'oneclass')
    return(x)
}

trainLearner.OneClassWrapper = function(.learner, .task, .subset = NULL, .weights = NULL, ...) {
    if (length(getTaskDesc(.task)$class.levels) <= 1) {
        x = list(oneclass=TRUE, value=.task$task.desc$positive)
        class(x) = "OneClassWrapperModel"
        return(makeChainModel(next.model = x, cl = c(.learner$model.subclass)))
    }

    model = train(.learner$next.learner, .task, .subset, .weights)

    x = list(oneclass=FALSE, model=model)
    class(x) = "OneClassWrapperModel"
    return(makeChainModel(next.model = x, cl = c(.learner$model.subclass)))
}

predictLearner.OneClassWrapper = function(.learner, .model, .newdata, ...) {
    .model = mlr::getLearnerModel(.model, more.unwrap = FALSE)

    if (.model$oneclass) {
        out = as.logical(rep(.model$value, nrow(.newdata)))
    }
    else {
        pred = predict(.model$model, newdata=.newdata)

        if (.learner$predict.type == "response") {
            out = getPredictionResponse(pred)
        } else {
            out = getPredictionProbabilities(pred, cl="TRUE")
        }
    }

    return(as.factor(out))
}

getLearnerProperties.OneClassWrapper = function(.learner) {
    return(.learner$properties)
}

isFailureModel.OneClassWrapperModel = function(model) {
    model = mlr::getLearnerModel(model, more.unwrap = FALSE)

  return(!model$oneclass && isFailureModel(model$model))
}

getFailureModelMsg.OneClassWrapperModel = function(model) {
    model = mlr::getLearnerModel(model, more.unwrap = FALSE)
  if (model$oneclass)
      return("")
  return(getFailureModelMsg(model$model))
}

getFailureModelDump.OneClassWrapperModel = function(model) {
    model = mlr::getLearnerModel(model, more.unwrap = FALSE)
  if (model$oneclass)
      return("")
  return(getFailureModelDump(model$model))
}

registerS3method("trainLearner", "<OneClassWrapper>", 
  trainLearner.OneClassWrapper)
registerS3method("getLearnerProperties", "<OneClassWrapper>", 
  getLearnerProperties.OneClassWrapper)
registerS3method("isFailureModel", "<OneClassWrapperModel>", 
  isFailureModel.OneClassWrapperModel)
registerS3method("getFailureModelMsg", "<OneClassWrapperModel>", 
  getFailureModelMsg.OneClassWrapperModel)
registerS3method("getFailureModelDump", "<OneClassWrapperModel>", 
  getFailureModelDump.OneClassWrapperModel)
...