Skip fitting the final model with caret
0 votes
/ 24 October 2018

Sometimes when I fit a model with caret, I am really only interested in seeing how it performs under my chosen resampling method (e.g., cross-validation).

When I do not care about the "final model" built on the full training data, I would like to avoid fitting it. It really is just about saving precious minutes, several times over, during development.

Is there a way to skip fitting the final model when using caret? I did not see a relevant argument in caret::trainControl or caret::train.

1 Answer

0 votes
/ 24 October 2018

There really does not seem to be an argument that achieves this directly. There are, however, a few workarounds.

  1. The selectionFunction argument of trainControl selects the final model based on the performance of the candidate models (there being only one candidate when no parameters are tuned) in terms of accuracy, RMSE, and so on. Setting selectionFunction to something like function(x, ...) NA or function(x, ...) NULL fails. Something like function(x, ...) -1, however, works partially: no warning or error is returned, but fitting the final model is still attempted. The end result appears to be model-dependent (see the sketch below).
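
    For instance, something along the following lines (just a sketch; what happens to the final fit after this varies by model):

    # Sketch: a dummy selectionFunction returning an invalid index (-1).
    # Selection itself raises no warning or error, but the behaviour of the
    # subsequent final fit is model-dependent.
    library(caret)
    ctrl <- trainControl(method = "cv", number = 5,
                         selectionFunction = function(x, ...) -1)
    fit <- train(Sepal.Length ~ ., data = iris, method = "lm",
                 trControl = ctrl)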

  2. Another interesting trainControl argument is indexFinal:

    an optional vector of integers indicating which samples are used to fit the final model after resampling. If NULL, then the entire data set is used.

    Setting it to NA does not seem to work for most models other than kNN. Setting it to something like 1:10 fits the final model, provided there are few enough parameters, using only those ten observations. Hence, setting it to something like 1:100 should work in many cases and take little time (see the sketch below).
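
    For instance, something like this (just a sketch; resampling runs as usual, only the final fit is restricted to the first hundred rows):

    # Sketch: cross-validation is unaffected, but the "final" model is fit
    # on rows 1:100 only, which is typically fast.
    library(caret)
    data(iris)
    ctrl <- trainControl(method = "cv", number = 5, indexFinal = 1:100)
    fit  <- train(Sepal.Length ~ ., data = iris, method = "lm",
                  trControl = ctrl)
    fit$results            # cross-validated performance, computed as usual
    nobs(fit$finalModel)   # 100: only rows 1:100 went into the final lm fit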

  3. You can, of course, modify the train function itself. Below I only add a fitFinal argument, TRUE by default, and check whether it is TRUE before fitting the final model. If fitFinal == FALSE, then

    finalModel <- list(fit = NULL, preProc = NULL)
    finalTime <- 0
    

    Everything else runs smoothly. As for overwriting the actual train.default function, afterwards you should run

    environment(myTrain) <- environment(caret:::train.default)
    assignInNamespace("train.default", myTrain, ns = "caret")
    

    So, we have

    myTrain <- function (x, y, method = "rf", preProcess = NULL, ..., weights = NULL, fitFinal = TRUE,
                         metric = ifelse(is.factor(y), "Accuracy", "RMSE"), maximize = ifelse(metric %in%
                                                                                                c("RMSE", "logLoss", "MAE"), FALSE, TRUE), trControl = trainControl(),
                         tuneGrid = NULL, tuneLength = ifelse(trControl$method ==
                                                                "none", 1, 3))
    {
      startTime <- proc.time()
      rs_seed <- sample.int(.Machine$integer.max, 1L)
      if (is.null(colnames(x)))
        stop("Please use column names for `x`", call. = FALSE)
      if (is.character(y))
        y <- as.factor(y)
      if (!is.numeric(y) & !is.factor(y)) {
        stop("Please make sure `y` is a factor or numeric value.",
             call. = FALSE)
      }
      if (is.list(method)) {
        minNames <- c("library", "type", "parameters", "grid",
                      "fit", "predict", "prob")
        nameCheck <- minNames %in% names(method)
        if (!all(nameCheck))
          stop(paste("some required components are missing:",
                     paste(minNames[!nameCheck], collapse = ", ")),
               call. = FALSE)
        models <- method
        method <- "custom"
      }
      else {
        models <- getModelInfo(method, regex = FALSE)[[1]]
        if (length(models) == 0)
          stop(paste("Model", method, "is not in caret's built-in library"),
               call. = FALSE)
      }
      checkInstall(models$library)
      for (i in seq(along = models$library)) do.call("requireNamespaceQuietStop",
                                                     list(package = models$library[i]))
      if (any(names(models) == "check") && is.function(models$check)) {
        software_check <- models$check(models$library)
      }
      paramNames <- as.character(models$parameters$parameter)
      funcCall <- match.call(expand.dots = TRUE)
      modelType <- get_model_type(y)
      if (!(modelType %in% models$type))
        stop(paste("wrong model type for", tolower(modelType)),
             call. = FALSE)
      if (grepl("^svm", method) & grepl("String$", method)) {
        if (is.vector(x) && is.character(x)) {
          stop("'x' should be a character matrix with a single column for string kernel methods",
               call. = FALSE)
        }
        if (is.matrix(x) && is.numeric(x)) {
          stop("'x' should be a character matrix with a single column for string kernel methods",
               call. = FALSE)
        }
        if (is.data.frame(x)) {
          stop("'x' should be a character matrix with a single column for string kernel methods",
               call. = FALSE)
        }
      }
      if (modelType == "Regression" & length(unique(y)) == 2)
        warning(paste("You are trying to do regression and your outcome only has",
                      "two possible values Are you trying to do classification?",
                      "If so, use a 2 level factor as your outcome column."))
      if (modelType != "Classification" & !is.null(trControl$sampling))
        stop("sampling methods are only implemented for classification problems",
             call. = FALSE)
      if (!is.null(trControl$sampling)) {
        trControl$sampling <- parse_sampling(trControl$sampling)
      }
      if (any(class(x) == "data.table"))
        x <- as.data.frame(x)
      check_dims(x = x, y = y)
      n <- if (class(y)[1] == "Surv")
        nrow(y)
      else length(y)
      parallel_check("RWeka", models)
      parallel_check("keras", models)
      if (!is.null(preProcess) && !(all(names(preProcess) %in%
                                        ppMethods)))
        stop(paste("pre-processing methods are limited to:",
                   paste(ppMethods, collapse = ", ")), call. = FALSE)
      if (modelType == "Classification") {
        classLevels <- levels(y)
        attributes(classLevels) <- list(ordered = is.ordered(y))
        xtab <- table(y)
        if (any(xtab == 0)) {
          xtab_msg <- paste("'", names(xtab)[xtab == 0], "'",
                            collapse = ", ", sep = "")
          stop(paste("One or more factor levels in the outcome has no data:",
                     xtab_msg), call. = FALSE)
        }
        if (trControl$classProbs && any(classLevels != make.names(classLevels))) {
          stop(paste("At least one of the class levels is not a valid R variable name;",
                     "This will cause errors when class probabilities are generated because",
                     "the variables names will be converted to ",
                     paste(make.names(classLevels), collapse = ", "),
                     ". Please use factor levels that can be used as valid R variable names",
                     " (see ?make.names for help)."), call. = FALSE)
        }
        if (metric %in% c("RMSE", "Rsquared"))
          stop(paste("Metric", metric, "not applicable for classification models"),
               call. = FALSE)
        if (!trControl$classProbs && metric == "ROC")
          stop(paste("Class probabilities are needed to score models using the",
                     "area under the ROC curve. Set `classProbs = TRUE`",
                     "in the trainControl() function."), call. = FALSE)
        if (trControl$classProbs) {
          if (!is.function(models$prob)) {
            warning("Class probabilities were requested for a model that does not implement them")
            trControl$classProbs <- FALSE
          }
        }
      }
      else {
        if (metric %in% c("Accuracy", "Kappa"))
          stop(paste("Metric", metric, "not applicable for regression models"),
               call. = FALSE)
        classLevels <- NA
        if (trControl$classProbs) {
          warning("cannnot compute class probabilities for regression")
          trControl$classProbs <- FALSE
        }
      }
      if (trControl$method == "oob" & is.null(models$oob))
        stop("Out of bag estimates are not implemented for this model",
             call. = FALSE)
      trControl <- withr::with_seed(rs_seed, make_resamples(trControl,
                                                            outcome = y))
      if (is.logical(trControl$savePredictions)) {
        trControl$savePredictions <- if (trControl$savePredictions)
          "all"
        else "none"
      }
      else {
        if (!(trControl$savePredictions %in% c("all", "final",
                                               "none")))
          stop("`savePredictions` should be either logical or \"all\", \"final\" or \"none\"",
               call. = FALSE)
      }
      if (!is.null(preProcess)) {
        ppOpt <- list(options = preProcess)
        if (length(trControl$preProcOptions) > 0)
          ppOpt <- c(ppOpt, trControl$preProcOptions)
      }
      else ppOpt <- NULL
      if (is.null(tuneGrid)) {
        if (!is.null(ppOpt) && length(models$parameters$parameter) >
            1 && as.character(models$parameters$parameter) !=
            "parameter") {
          pp <- list(method = ppOpt$options)
          if ("ica" %in% pp$method)
            pp$n.comp <- ppOpt$ICAcomp
          if ("pca" %in% pp$method)
            pp$thresh <- ppOpt$thresh
          if ("knnImpute" %in% pp$method)
            pp$k <- ppOpt$k
          pp$x <- x
          ppObj <- do.call("preProcess", pp)
          tuneGrid <- models$grid(x = predict(ppObj, x), y = y,
                                  len = tuneLength, search = trControl$search)
          rm(ppObj, pp)
        }
        else {
          tuneGrid <- models$grid(x = x, y = y, len = tuneLength,
                                  search = trControl$search)
          if (trControl$search != "grid" && tuneLength < nrow(tuneGrid))
            tuneGrid <- tuneGrid[1:tuneLength, , drop = FALSE]
        }
      }
      if (grepl("adaptive", trControl$method) & nrow(tuneGrid) ==
          1) {
        stop(paste("For adaptive resampling, there needs to be more than one",
                   "tuning parameter for evaluation"), call. = FALSE)
      }
      dotNames <- hasDots(tuneGrid, models)
      if (dotNames)
        colnames(tuneGrid) <- gsub("^\\.", "", colnames(tuneGrid))
      tuneNames <- as.character(models$parameters$parameter)
      goodNames <- all.equal(sort(tuneNames), sort(names(tuneGrid)))
      if (!is.logical(goodNames) || !goodNames) {
        stop(paste("The tuning parameter grid should have columns",
                   paste(tuneNames, collapse = ", ", sep = "")), call. = FALSE)
      }
      if (trControl$method == "none" && nrow(tuneGrid) != 1)
        stop("Only one model should be specified in tuneGrid with no resampling",
             call. = FALSE)
      trControl$yLimits <- if (is.numeric(y))
        get_range(y)
      else NULL
      if (trControl$method != "none") {
        if (is.function(models$loop) && nrow(tuneGrid) > 1) {
          trainInfo <- models$loop(tuneGrid)
          if (!all(c("loop", "submodels") %in% names(trainInfo)))
            stop("The 'loop' function should produce a list with elements 'loop' and 'submodels'",
                 call. = FALSE)
          lengths <- unlist(lapply(trainInfo$submodels, nrow))
          if (all(lengths == 0))
            trainInfo$submodels <- NULL
        }
        else trainInfo <- list(loop = tuneGrid)
        num_rs <- if (trControl$method != "oob")
          length(trControl$index)
        else 1L
        if (trControl$method %in% c("boot632", "optimism_boot",
                                    "boot_all"))
          num_rs <- num_rs + 1L
        if (is.null(trControl$seeds) || all(is.na(trControl$seeds))) {
          seeds <- sample.int(n = 1000000L, size = num_rs *
                                nrow(trainInfo$loop) + 1L)
          seeds <- lapply(seq(from = 1L, to = length(seeds),
                              by = nrow(trainInfo$loop)), function(x) {
                                seeds[x:(x + nrow(trainInfo$loop) - 1L)]
                              })
          seeds[[num_rs + 1L]] <- seeds[[num_rs + 1L]][1L]
          trControl$seeds <- seeds
        }
        else {
          if (!(length(trControl$seeds) == 1 && is.na(trControl$seeds))) {
            numSeeds <- unlist(lapply(trControl$seeds, length))
            badSeed <- (length(trControl$seeds) < num_rs +
                          1L) || (any(numSeeds[-length(numSeeds)] < nrow(trainInfo$loop))) ||
              (numSeeds[length(numSeeds)] < 1L)
            if (badSeed)
              stop(paste("Bad seeds: the seed object should be a list of length",
                         num_rs + 1, "with", num_rs, "integer vectors of size",
                         nrow(trainInfo$loop), "and the last list element having at least a",
                         "single integer"), call. = FALSE)
            if (any(is.na(unlist(trControl$seeds))))
              stop("At least one seed is missing (NA)", call. = FALSE)
          }
        }
        if (trControl$method == "oob") {
          perfNames <- metric
        }
        else {
          testSummary <- evalSummaryFunction(y, wts = weights,
                                             ctrl = trControl, lev = classLevels, metric = metric,
                                             method = method)
          perfNames <- names(testSummary)
        }
        if (!(metric %in% perfNames)) {
          oldMetric <- metric
          metric <- perfNames[1]
          warning(paste("The metric \"", oldMetric, "\" was not in ",
                        "the result set. ", metric, " will be used instead.",
                        sep = ""))
        }
        if (trControl$method == "oob") {
          tmp <- oobTrainWorkflow(x = x, y = y, wts = weights,
                                  info = trainInfo, method = models, ppOpts = preProcess,
                                  ctrl = trControl, lev = classLevels, ...)
          performance <- tmp
          perfNames <- colnames(performance)
          perfNames <- perfNames[!(perfNames %in% as.character(models$parameters$parameter))]
          if (!(metric %in% perfNames)) {
            oldMetric <- metric
            metric <- perfNames[1]
            warning(paste("The metric \"", oldMetric, "\" was not in ",
                          "the result set. ", metric, " will be used instead.",
                          sep = ""))
          }
        }
        else {
          if (trControl$method == "LOOCV") {
            tmp <- looTrainWorkflow(x = x, y = y, wts = weights,
                                    info = trainInfo, method = models, ppOpts = preProcess,
                                    ctrl = trControl, lev = classLevels, ...)
            performance <- tmp$performance
          }
          else {
            if (!grepl("adapt", trControl$method)) {
              tmp <- nominalTrainWorkflow(x = x, y = y, wts = weights,
                                          info = trainInfo, method = models, ppOpts = preProcess,
                                          ctrl = trControl, lev = classLevels, ...)
              performance <- tmp$performance
              resampleResults <- tmp$resample
            }
            else {
              tmp <- adaptiveWorkflow(x = x, y = y, wts = weights,
                                      info = trainInfo, method = models, ppOpts = preProcess,
                                      ctrl = trControl, lev = classLevels, metric = metric,
                                      maximize = maximize, ...)
              performance <- tmp$performance
              resampleResults <- tmp$resample
            }
          }
        }
        trControl$indexExtra <- NULL
        if (!(trControl$method %in% c("LOOCV", "oob"))) {
          if (modelType == "Classification" && length(grep("^\\cell",
                                                           colnames(resampleResults))) > 0) {
            resampledCM <- resampleResults[, !(names(resampleResults) %in%
                                                 perfNames)]
            resampleResults <- resampleResults[, -grep("^\\cell",
                                                       colnames(resampleResults))]
          }
          else resampledCM <- NULL
        }
        else resampledCM <- NULL
        if (trControl$verboseIter) {
          cat("Aggregating results\n")
          flush.console()
        }
        perfCols <- names(performance)
        perfCols <- perfCols[!(perfCols %in% paramNames)]
        if (all(is.na(performance[, metric]))) {
          cat(paste("Something is wrong; all the", metric,
                    "metric values are missing:\n"))
          print(summary(performance[, perfCols[!grepl("SD$",
                                                      perfCols)], drop = FALSE]))
          stop("Stopping", call. = FALSE)
        }
        if (!is.null(models$sort))
          performance <- models$sort(performance)
        if (any(is.na(performance[, metric])))
          warning("missing values found in aggregated results")
        if (trControl$verboseIter && nrow(performance) > 1) {
          cat("Selecting tuning parameters\n")
          flush.console()
        }
        selectClass <- class(trControl$selectionFunction)[1]
        if (grepl("adapt", trControl$method)) {
          perf_check <- subset(performance, .B == max(performance$.B))
        }
        else perf_check <- performance
        if (selectClass == "function") {
          bestIter <- trControl$selectionFunction(x = perf_check,
                                                  metric = metric, maximize = maximize)
        }
        else {
          if (trControl$selectionFunction == "oneSE") {
            bestIter <- oneSE(perf_check, metric, length(trControl$index),
                              maximize)
          }
          else {
            bestIter <- do.call(trControl$selectionFunction,
                                list(x = perf_check, metric = metric, maximize = maximize))
          }
        }
        if (is.na(bestIter) || length(bestIter) != 1)
          stop("final tuning parameters could not be determined",
               call. = FALSE)
        if (grepl("adapt", trControl$method)) {
          best_perf <- perf_check[bestIter, as.character(models$parameters$parameter),
                                  drop = FALSE]
          performance$order <- 1:nrow(performance)
          bestIter <- merge(performance, best_perf)$order
          performance$order <- NULL
        }
        bestTune <- performance[bestIter, paramNames, drop = FALSE]
      }
      else {
        bestTune <- tuneGrid
        performance <- evalSummaryFunction(y, wts = weights,
                                           ctrl = trControl, lev = classLevels, metric = metric,
                                           method = method)
        perfNames <- names(performance)
        performance <- as.data.frame(t(performance))
        performance <- cbind(performance, tuneGrid)
        performance <- performance[-1, , drop = FALSE]
        tmp <- resampledCM <- NULL
      }
      if (!(trControl$method %in% c("LOOCV", "oob", "none"))) {
        byResample <- switch(trControl$returnResamp, none = NULL,
                             all = {
                               out <- resampleResults
                               colnames(out) <- gsub("^\\.", "", colnames(out))
                               out
                             }, final = {
                               out <- merge(bestTune, resampleResults)
                               out <- out[, !(names(out) %in% names(tuneGrid)),
                                          drop = FALSE]
                               out
                             })
      }
      else {
        byResample <- NULL
      }
      orderList <- list()
      for (i in seq(along = paramNames)) orderList[[i]] <- performance[,
                                                                       paramNames[i]]
      performance <- performance[do.call("order", orderList), ]
      if (trControl$verboseIter) {
        bestText <- paste(paste(names(bestTune), "=", format(bestTune,
                                                             digits = 3)), collapse = ", ")
        if (nrow(performance) == 1)
          bestText <- "final model"
        cat("Fitting", bestText, "on full training set\n")
        flush.console()
      }
      indexFinal <- if (is.null(trControl$indexFinal))
        seq(along = y)
      else trControl$indexFinal
      if (!(length(trControl$seeds) == 1 && is.na(trControl$seeds)))
        set.seed(trControl$seeds[[length(trControl$seeds)]][1])
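      # Modification: fit the final model on the full training data only when
      # fitFinal is TRUE; otherwise return empty placeholders below.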
      if (fitFinal) {
        finalTime <- system.time(finalModel <- createModel(x = subset_x(x,
                                                                        indexFinal), y = y[indexFinal], wts = weights[indexFinal],
                                                           method = models, tuneValue = bestTune, obsLevels = classLevels,
                                                           pp = ppOpt, last = TRUE, classProbs = trControl$classProbs,
                                                           sampling = trControl$sampling, ...))
      } else {
        finalModel <- list(fit = NULL, preProc = NULL)
        finalTime <- 0
      }
      if (trControl$trim && !is.null(models$trim)) {
        if (trControl$verboseIter)
          old_size <- object.size(finalModel$fit)
        finalModel$fit <- models$trim(finalModel$fit)
        if (trControl$verboseIter) {
          new_size <- object.size(finalModel$fit)
          reduction <- format(old_size - new_size, units = "Mb")
          if (reduction == "0 Mb")
            reduction <- "< 0 Mb"
          p_reduction <- (unclass(old_size) - unclass(new_size))/unclass(old_size) *
            100
          p_reduction <- if (p_reduction < 1)
            "< 1%"
          else paste0(round(p_reduction, 0), "%")
          cat("Final model footprint reduced by", reduction,
              "or", p_reduction, "\n")
        }
      }
      pp <- finalModel$preProc
      finalModel <- finalModel$fit
      if (method == "pls")
        finalModel$bestIter <- bestTune
      if (method == "glmnet")
        finalModel$lambdaOpt <- bestTune$lambda
      if (trControl$returnData) {
        outData <- if (!is.data.frame(x))
          try(as.data.frame(x), silent = TRUE)
        else x
        if (inherits(outData, "try-error")) {
          warning("The training data could not be converted to a data frame for saving")
          outData <- NULL
        }
        else {
          outData$.outcome <- y
          if (!is.null(weights))
            outData$.weights <- weights
        }
      }
      else outData <- NULL
      if (trControl$savePredictions == "final")
        tmp$predictions <- merge(bestTune, tmp$predictions)
      endTime <- proc.time()
      times <- list(everything = endTime - startTime, final = finalTime)
      out <- structure(list(method = method, modelInfo = models,
                            modelType = modelType, results = performance, pred = tmp$predictions,
                            bestTune = bestTune, call = funcCall, dots = list(...),
                            metric = metric, control = trControl, finalModel = finalModel,
                            preProcess = pp, trainingData = outData, resample = byResample,
                            resampledCM = resampledCM, perfNames = perfNames, maximize = maximize,
                            yLimits = trControl$yLimits, times = times, levels = classLevels),
                       class = "train")
      trControl$yLimits <- NULL
      if (trControl$timingSamps > 0) {
        pData <- x[sample(1:nrow(x), trControl$timingSamps, replace = TRUE),
                   , drop = FALSE]
        out$times$prediction <- system.time(predict(out, pData))
      }
      else out$times$prediction <- rep(NA, 3)
      out
    }
    

    which gives

    data(iris)
    TrainData <- iris[,1:4]
    TrainClasses <- iris[,5]
    
    knnFit1 <- train(TrainData, TrainClasses,
                     method = "knn",
                     preProcess = c("center", "scale"),
                     tuneLength = 10,
                     trControl = trainControl(method = "cv"), fitFinal = FALSE)
    knnFit1$finalModel
    # NULL
    