Вы, конечно, можете изменить саму функцию `train`. Ниже я лишь добавляю аргумент `fitFinal`, который по умолчанию равен `TRUE`, и проверяю его значение перед обучением финальной модели: она обучается, только если `fitFinal` равен `TRUE`. Если же `fitFinal == FALSE`, то
# Placeholder values used when the final fit is skipped: no fitted model,
# no pre-processing object, and zero time recorded for the final fit.
finalModel <- list(fit = NULL, preProc = NULL)
finalTime <- 0
Все остальное работает как прежде. Что касается подмены настоящей функции `train.default`, то после определения `myTrain` (она приведена ниже) нужно выполнить:
# Give myTrain the same enclosing environment as caret's train.default, then
# replace train.default inside the caret namespace with myTrain.
# NOTE(review): presumably the environment swap is what lets myTrain resolve
# the helper functions it calls without a `caret:::` prefix -- confirm.
environment(myTrain) <- environment(caret:::train.default)
assignInNamespace("train.default", myTrain, ns = "caret")
Итак, у нас есть
# Variant of caret's train.default with one extra argument, `fitFinal`.
#
# Validates the inputs, builds the tuning grid, runs the resampling workflow
# selected by `trControl$method`, picks the best tuning parameters and then,
# only when `fitFinal` is TRUE, fits the model on the full training set.
#
# Arguments:
#   x, y        Predictors (column names are required) and outcome (numeric or
#               factor; a character outcome is converted to a factor).
#   method      Model name looked up with getModelInfo(), or a list describing
#               a custom model (must contain the components in `minNames`).
#   preProcess  Pre-processing methods, or NULL.
#   ...         Forwarded to the resampling workflows and to createModel().
#   weights     Optional case weights.
#   fitFinal    If TRUE (default) the final model is fitted. If FALSE the final
#               fit is skipped together with every step that needs a fitted
#               model (trimming, the pls/glmnet bookkeeping and the prediction
#               timing); the result then has `finalModel = NULL`,
#               `preProcess = NULL` and `times$final = 0`.
#   metric      Name of the performance metric used to pick the best model.
#   maximize    Whether larger values of `metric` are better.
#   trControl   Control object, by default trainControl().
#   tuneGrid    Data frame of candidate tuning parameters, or NULL to let the
#               model's own grid function build one.
#   tuneLength  Size hint passed to the model's grid function.
#
# Returns an object of class "train".
#
# The body calls helper functions without a package prefix, so it has to be
# evaluated in an environment where those helpers are visible.
myTrain <- function (x, y, method = "rf", preProcess = NULL, ..., weights = NULL, fitFinal = TRUE,
                     metric = ifelse(is.factor(y), "Accuracy", "RMSE"),
                     maximize = ifelse(metric %in% c("RMSE", "logLoss", "MAE"), FALSE, TRUE),
                     trControl = trainControl(),
                     tuneGrid = NULL,
                     tuneLength = ifelse(trControl$method == "none", 1, 3))
{
  startTime <- proc.time()
  # Seed used below so that the resamples are created reproducibly.
  rs_seed <- sample.int(.Machine$integer.max, 1L)

  ## ---- Input validation ----
  if (is.null(colnames(x)))
    stop("Please use column names for `x`", call. = FALSE)
  if (is.character(y))
    y <- as.factor(y)
  if (!is.numeric(y) & !is.factor(y)) {
    stop("Please make sure `y` is a factor or numeric value.",
         call. = FALSE)
  }

  ## ---- Resolve the model description ----
  if (is.list(method)) {
    minNames <- c("library", "type", "parameters", "grid",
                  "fit", "predict", "prob")
    nameCheck <- minNames %in% names(method)
    if (!all(nameCheck))
      stop(paste("some required components are missing:",
                 paste(minNames[!nameCheck], collapse = ", ")),
           call. = FALSE)
    models <- method
    method <- "custom"
  }
  else {
    models <- getModelInfo(method, regex = FALSE)[[1]]
    if (length(models) == 0)
      stop(paste("Model", method, "is not in caret's built-in library"),
           call. = FALSE)
  }
  checkInstall(models$library)
  for (i in seq(along = models$library)) do.call("requireNamespaceQuietStop",
                                                 list(package = models$library[i]))
  if (any(names(models) == "check") && is.function(models$check)) {
    software_check <- models$check(models$library)
  }
  paramNames <- as.character(models$parameters$parameter)
  funcCall <- match.call(expand.dots = TRUE)
  modelType <- get_model_type(y)
  if (!(modelType %in% models$type))
    stop(paste("wrong model type for", tolower(modelType)),
         call. = FALSE)
  if (grepl("^svm", method) & grepl("String$", method)) {
    if (is.vector(x) && is.character(x)) {
      stop("'x' should be a character matrix with a single column for string kernel methods",
           call. = FALSE)
    }
    if (is.matrix(x) && is.numeric(x)) {
      stop("'x' should be a character matrix with a single column for string kernel methods",
           call. = FALSE)
    }
    if (is.data.frame(x)) {
      stop("'x' should be a character matrix with a single column for string kernel methods",
           call. = FALSE)
    }
  }
  if (modelType == "Regression" & length(unique(y)) == 2)
    warning(paste("You are trying to do regression and your outcome only has",
                  "two possible values Are you trying to do classification?",
                  "If so, use a 2 level factor as your outcome column."))
  if (modelType != "Classification" & !is.null(trControl$sampling))
    stop("sampling methods are only implemented for classification problems",
         call. = FALSE)
  if (!is.null(trControl$sampling)) {
    trControl$sampling <- parse_sampling(trControl$sampling)
  }
  if (any(class(x) == "data.table"))
    x <- as.data.frame(x)
  check_dims(x = x, y = y)
  n <- if (class(y)[1] == "Surv")
    nrow(y)
  else length(y)
  parallel_check("RWeka", models)
  parallel_check("keras", models)
  if (!is.null(preProcess) && !(all(names(preProcess) %in%
                                    ppMethods)))
    stop(paste("pre-processing methods are limited to:",
               paste(ppMethods, collapse = ", ")), call. = FALSE)

  ## ---- Checks that depend on the problem type ----
  if (modelType == "Classification") {
    classLevels <- levels(y)
    attributes(classLevels) <- list(ordered = is.ordered(y))
    xtab <- table(y)
    if (any(xtab == 0)) {
      xtab_msg <- paste("'", names(xtab)[xtab == 0], "'",
                        collapse = ", ", sep = "")
      stop(paste("One or more factor levels in the outcome has no data:",
                 xtab_msg), call. = FALSE)
    }
    if (trControl$classProbs && any(classLevels != make.names(classLevels))) {
      stop(paste("At least one of the class levels is not a valid R variable name;",
                 "This will cause errors when class probabilities are generated because",
                 "the variables names will be converted to ",
                 paste(make.names(classLevels), collapse = ", "),
                 ". Please use factor levels that can be used as valid R variable names",
                 " (see ?make.names for help)."), call. = FALSE)
    }
    if (metric %in% c("RMSE", "Rsquared"))
      stop(paste("Metric", metric, "not applicable for classification models"),
           call. = FALSE)
    if (!trControl$classProbs && metric == "ROC")
      stop(paste("Class probabilities are needed to score models using the",
                 "area under the ROC curve. Set `classProbs = TRUE`",
                 "in the trainControl() function."), call. = FALSE)
    if (trControl$classProbs) {
      if (!is.function(models$prob)) {
        warning("Class probabilities were requested for a model that does not implement them")
        trControl$classProbs <- FALSE
      }
    }
  }
  else {
    if (metric %in% c("Accuracy", "Kappa"))
      stop(paste("Metric", metric, "not applicable for regression models"),
           call. = FALSE)
    classLevels <- NA
    if (trControl$classProbs) {
      warning("cannnot compute class probabilities for regression")
      trControl$classProbs <- FALSE
    }
  }
  if (trControl$method == "oob" & is.null(models$oob))
    stop("Out of bag estimates are not implemented for this model",
         call. = FALSE)

  ## ---- Resamples and control options ----
  trControl <- withr::with_seed(rs_seed, make_resamples(trControl,
                                                        outcome = y))
  # Normalise a logical `savePredictions` to the character form used later.
  if (is.logical(trControl$savePredictions)) {
    trControl$savePredictions <- if (trControl$savePredictions)
      "all"
    else "none"
  }
  else {
    if (!(trControl$savePredictions %in% c("all", "final",
                                           "none")))
      stop("`savePredictions` should be either logical or \"all\", \"final\" or \"none\"",
           call. = FALSE)
  }
  if (!is.null(preProcess)) {
    ppOpt <- list(options = preProcess)
    if (length(trControl$preProcOptions) > 0)
      ppOpt <- c(ppOpt, trControl$preProcOptions)
  }
  else ppOpt <- NULL

  ## ---- Tuning grid ----
  if (is.null(tuneGrid)) {
    # all() keeps the `&&` operand scalar: with more than one tuning
    # parameter the comparison is a vector, which `&&` rejects in R >= 4.3.
    if (!is.null(ppOpt) && length(models$parameters$parameter) > 1 &&
        all(as.character(models$parameters$parameter) != "parameter")) {
      # Build the grid from the pre-processed predictors.
      pp <- list(method = ppOpt$options)
      if ("ica" %in% pp$method)
        pp$n.comp <- ppOpt$ICAcomp
      if ("pca" %in% pp$method)
        pp$thresh <- ppOpt$thresh
      if ("knnImpute" %in% pp$method)
        pp$k <- ppOpt$k
      pp$x <- x
      ppObj <- do.call("preProcess", pp)
      tuneGrid <- models$grid(x = predict(ppObj, x), y = y,
                              len = tuneLength, search = trControl$search)
      rm(ppObj, pp)
    }
    else {
      tuneGrid <- models$grid(x = x, y = y, len = tuneLength,
                              search = trControl$search)
      if (trControl$search != "grid" && tuneLength < nrow(tuneGrid))
        tuneGrid <- tuneGrid[1:tuneLength, , drop = FALSE]
    }
  }
  if (grepl("adaptive", trControl$method) & nrow(tuneGrid) ==
      1) {
    stop(paste("For adaptive resampling, there needs to be more than one",
               "tuning parameter for evaluation"), call. = FALSE)
  }
  # Strip a leading "." from the grid's column names when hasDots() says so.
  dotNames <- hasDots(tuneGrid, models)
  if (dotNames)
    colnames(tuneGrid) <- gsub("^\\.", "", colnames(tuneGrid))
  tuneNames <- as.character(models$parameters$parameter)
  goodNames <- all.equal(sort(tuneNames), sort(names(tuneGrid)))
  if (!is.logical(goodNames) || !goodNames) {
    stop(paste("The tuning parameter grid should have columns",
               paste(tuneNames, collapse = ", ", sep = "")), call. = FALSE)
  }
  if (trControl$method == "none" && nrow(tuneGrid) != 1)
    stop("Only one model should be specified in tuneGrid with no resampling",
         call. = FALSE)
  trControl$yLimits <- if (is.numeric(y))
    get_range(y)
  else NULL

  ## ---- Resampling and parameter selection ----
  if (trControl$method != "none") {
    if (is.function(models$loop) && nrow(tuneGrid) > 1) {
      trainInfo <- models$loop(tuneGrid)
      if (!all(c("loop", "submodels") %in% names(trainInfo)))
        stop("The 'loop' function should produce a list with elements 'loop' and 'submodels'",
             call. = FALSE)
      lengths <- unlist(lapply(trainInfo$submodels, nrow))
      if (all(lengths == 0))
        trainInfo$submodels <- NULL
    }
    else trainInfo <- list(loop = tuneGrid)
    num_rs <- if (trControl$method != "oob")
      length(trControl$index)
    else 1L
    if (trControl$method %in% c("boot632", "optimism_boot",
                                "boot_all"))
      num_rs <- num_rs + 1L
    if (is.null(trControl$seeds) || all(is.na(trControl$seeds))) {
      # No seeds supplied: draw one vector of seeds per resample plus a
      # single seed (the last list element) that is used before the final fit.
      seeds <- sample.int(n = 1000000L, size = num_rs *
                            nrow(trainInfo$loop) + 1L)
      seeds <- lapply(seq(from = 1L, to = length(seeds),
                          by = nrow(trainInfo$loop)), function(x) {
                            seeds[x:(x + nrow(trainInfo$loop) - 1L)]
                          })
      seeds[[num_rs + 1L]] <- seeds[[num_rs + 1L]][1L]
      trControl$seeds <- seeds
    }
    else {
      if (!(length(trControl$seeds) == 1 && is.na(trControl$seeds))) {
        numSeeds <- unlist(lapply(trControl$seeds, length))
        badSeed <- (length(trControl$seeds) < num_rs +
                      1L) || (any(numSeeds[-length(numSeeds)] < nrow(trainInfo$loop))) ||
          (numSeeds[length(numSeeds)] < 1L)
        if (badSeed)
          stop(paste("Bad seeds: the seed object should be a list of length",
                     num_rs + 1, "with", num_rs, "integer vectors of size",
                     nrow(trainInfo$loop), "and the last list element having at least a",
                     "single integer"), call. = FALSE)
        if (any(is.na(unlist(trControl$seeds))))
          stop("At least one seed is missing (NA)", call. = FALSE)
      }
    }
    if (trControl$method == "oob") {
      perfNames <- metric
    }
    else {
      testSummary <- evalSummaryFunction(y, wts = weights,
                                         ctrl = trControl, lev = classLevels, metric = metric,
                                         method = method)
      perfNames <- names(testSummary)
    }
    if (!(metric %in% perfNames)) {
      oldMetric <- metric
      metric <- perfNames[1]
      warning(paste("The metric \"", oldMetric, "\" was not in ",
                    "the result set. ", metric, " will be used instead.",
                    sep = ""))
    }
    # Dispatch to the workflow that matches the resampling method.
    if (trControl$method == "oob") {
      tmp <- oobTrainWorkflow(x = x, y = y, wts = weights,
                              info = trainInfo, method = models, ppOpts = preProcess,
                              ctrl = trControl, lev = classLevels, ...)
      performance <- tmp
      perfNames <- colnames(performance)
      perfNames <- perfNames[!(perfNames %in% as.character(models$parameters$parameter))]
      if (!(metric %in% perfNames)) {
        oldMetric <- metric
        metric <- perfNames[1]
        warning(paste("The metric \"", oldMetric, "\" was not in ",
                      "the result set. ", metric, " will be used instead.",
                      sep = ""))
      }
    }
    else {
      if (trControl$method == "LOOCV") {
        tmp <- looTrainWorkflow(x = x, y = y, wts = weights,
                                info = trainInfo, method = models, ppOpts = preProcess,
                                ctrl = trControl, lev = classLevels, ...)
        performance <- tmp$performance
      }
      else {
        if (!grepl("adapt", trControl$method)) {
          tmp <- nominalTrainWorkflow(x = x, y = y, wts = weights,
                                      info = trainInfo, method = models, ppOpts = preProcess,
                                      ctrl = trControl, lev = classLevels, ...)
          performance <- tmp$performance
          resampleResults <- tmp$resample
        }
        else {
          tmp <- adaptiveWorkflow(x = x, y = y, wts = weights,
                                  info = trainInfo, method = models, ppOpts = preProcess,
                                  ctrl = trControl, lev = classLevels, metric = metric,
                                  maximize = maximize, ...)
          performance <- tmp$performance
          resampleResults <- tmp$resample
        }
      }
    }
    trControl$indexExtra <- NULL
    # Split the per-resample confusion-matrix cells off from the metrics.
    if (!(trControl$method %in% c("LOOCV", "oob"))) {
      if (modelType == "Classification" && length(grep("^\\cell",
                                                       colnames(resampleResults))) > 0) {
        resampledCM <- resampleResults[, !(names(resampleResults) %in%
                                             perfNames)]
        resampleResults <- resampleResults[, -grep("^\\cell",
                                                   colnames(resampleResults))]
      }
      else resampledCM <- NULL
    }
    else resampledCM <- NULL
    if (trControl$verboseIter) {
      cat("Aggregating results\n")
      flush.console()
    }
    perfCols <- names(performance)
    perfCols <- perfCols[!(perfCols %in% paramNames)]
    if (all(is.na(performance[, metric]))) {
      cat(paste("Something is wrong; all the", metric,
                "metric values are missing:\n"))
      print(summary(performance[, perfCols[!grepl("SD$",
                                                  perfCols)], drop = FALSE]))
      stop("Stopping", call. = FALSE)
    }
    if (!is.null(models$sort))
      performance <- models$sort(performance)
    if (any(is.na(performance[, metric])))
      warning("missing values found in aggregated results")
    if (trControl$verboseIter && nrow(performance) > 1) {
      cat("Selecting tuning parameters\n")
      flush.console()
    }
    selectClass <- class(trControl$selectionFunction)[1]
    if (grepl("adapt", trControl$method)) {
      perf_check <- subset(performance, .B == max(performance$.B))
    }
    else perf_check <- performance
    if (selectClass == "function") {
      bestIter <- trControl$selectionFunction(x = perf_check,
                                              metric = metric, maximize = maximize)
    }
    else {
      if (trControl$selectionFunction == "oneSE") {
        bestIter <- oneSE(perf_check, metric, length(trControl$index),
                          maximize)
      }
      else {
        bestIter <- do.call(trControl$selectionFunction,
                            list(x = perf_check, metric = metric, maximize = maximize))
      }
    }
    # Check the length first so that is.na() is only evaluated on a scalar.
    if (length(bestIter) != 1 || is.na(bestIter))
      stop("final tuning parameters could not be determined",
           call. = FALSE)
    if (grepl("adapt", trControl$method)) {
      # Map the row chosen in `perf_check` back to its row in `performance`.
      best_perf <- perf_check[bestIter, as.character(models$parameters$parameter),
                              drop = FALSE]
      performance$order <- 1:nrow(performance)
      bestIter <- merge(performance, best_perf)$order
      performance$order <- NULL
    }
    bestTune <- performance[bestIter, paramNames, drop = FALSE]
  }
  else {
    # No resampling: the single grid row is the "best" tune and the
    # performance table keeps its columns but has no rows.
    bestTune <- tuneGrid
    performance <- evalSummaryFunction(y, wts = weights,
                                       ctrl = trControl, lev = classLevels, metric = metric,
                                       method = method)
    perfNames <- names(performance)
    performance <- as.data.frame(t(performance))
    performance <- cbind(performance, tuneGrid)
    performance <- performance[-1, , drop = FALSE]
    tmp <- resampledCM <- NULL
  }
  if (!(trControl$method %in% c("LOOCV", "oob", "none"))) {
    byResample <- switch(trControl$returnResamp, none = NULL,
                         all = {
                           out <- resampleResults
                           colnames(out) <- gsub("^\\.", "", colnames(out))
                           out
                         }, final = {
                           out <- merge(bestTune, resampleResults)
                           out <- out[, !(names(out) %in% names(tuneGrid)),
                                      drop = FALSE]
                           out
                         })
  }
  else {
    byResample <- NULL
  }
  # Order the results table by the tuning parameters.
  orderList <- list()
  for (i in seq(along = paramNames)) orderList[[i]] <- performance[,
                                                                   paramNames[i]]
  performance <- performance[do.call("order", orderList), ]

  ## ---- Final model ----
  if (trControl$verboseIter) {
    if (fitFinal) {
      bestText <- paste(paste(names(bestTune), "=", format(bestTune,
                                                           digits = 3)), collapse = ", ")
      if (nrow(performance) == 1)
        bestText <- "final model"
      cat("Fitting", bestText, "on full training set\n")
    }
    else {
      # Do not announce a fit that is not going to happen.
      cat("Skipping the final model fit (fitFinal = FALSE)\n")
    }
    flush.console()
  }
  indexFinal <- if (is.null(trControl$indexFinal))
    seq(along = y)
  else trControl$indexFinal
  if (!(length(trControl$seeds) == 1 && is.na(trControl$seeds)))
    set.seed(trControl$seeds[[length(trControl$seeds)]][1])
  if (fitFinal) {
    finalTime <- system.time(finalModel <- createModel(x = subset_x(x,
                                                                    indexFinal), y = y[indexFinal], wts = weights[indexFinal],
                                                       method = models, tuneValue = bestTune, obsLevels = classLevels,
                                                       pp = ppOpt, last = TRUE, classProbs = trControl$classProbs,
                                                       sampling = trControl$sampling, ...))
  } else {
    finalModel <- list(fit = NULL, preProc = NULL)
    finalTime <- 0
  }
  # Trimming needs a fitted model, so it is skipped when nothing was fitted.
  if (fitFinal && trControl$trim && !is.null(models$trim)) {
    if (trControl$verboseIter)
      old_size <- object.size(finalModel$fit)
    finalModel$fit <- models$trim(finalModel$fit)
    if (trControl$verboseIter) {
      new_size <- object.size(finalModel$fit)
      reduction <- format(old_size - new_size, units = "Mb")
      if (reduction == "0 Mb")
        reduction <- "< 0 Mb"
      p_reduction <- (unclass(old_size) - unclass(new_size))/unclass(old_size) *
        100
      p_reduction <- if (p_reduction < 1)
        "< 1%"
      else paste0(round(p_reduction, 0), "%")
      cat("Final model footprint reduced by", reduction,
          "or", p_reduction, "\n")
    }
  }
  pp <- finalModel$preProc
  finalModel <- finalModel$fit
  # Only annotate a model that exists: `$<-` on NULL would silently create a
  # list, so `finalModel` would no longer be NULL when the fit was skipped.
  if (fitFinal && method == "pls")
    finalModel$bestIter <- bestTune
  if (fitFinal && method == "glmnet")
    finalModel$lambdaOpt <- bestTune$lambda

  ## ---- Assemble the result ----
  if (trControl$returnData) {
    outData <- if (!is.data.frame(x))
      try(as.data.frame(x), silent = TRUE)
    else x
    if (inherits(outData, "try-error")) {
      warning("The training data could not be converted to a data frame for saving")
      outData <- NULL
    }
    else {
      outData$.outcome <- y
      if (!is.null(weights))
        outData$.weights <- weights
    }
  }
  else outData <- NULL
  if (trControl$savePredictions == "final")
    tmp$predictions <- merge(bestTune, tmp$predictions)
  endTime <- proc.time()
  times <- list(everything = endTime - startTime, final = finalTime)
  out <- structure(list(method = method, modelInfo = models,
                        modelType = modelType, results = performance, pred = tmp$predictions,
                        bestTune = bestTune, call = funcCall, dots = list(...),
                        metric = metric, control = trControl, finalModel = finalModel,
                        preProcess = pp, trainingData = outData, resample = byResample,
                        resampledCM = resampledCM, perfNames = perfNames, maximize = maximize,
                        yLimits = trControl$yLimits, times = times, levels = classLevels),
                   class = "train")
  trControl$yLimits <- NULL
  # Timing predictions requires a fitted model; report NA when there is none.
  if (fitFinal && trControl$timingSamps > 0) {
    pData <- x[sample(1:nrow(x), trControl$timingSamps, replace = TRUE),
               , drop = FALSE]
    out$times$prediction <- system.time(predict(out, pData))
  }
  else out$times$prediction <- rep(NA, 3)
  out
}
В результате получаем:
# Usage example: tune a k-nearest-neighbours classifier on iris with
# cross-validation, skipping the final model fit.
data(iris)
TrainData <- iris[, seq_len(4)]
TrainClasses <- iris[[5]]
knnFit1 <- train(
  TrainData,
  TrainClasses,
  method = "knn",
  preProcess = c("center", "scale"),
  tuneLength = 10,
  trControl = trainControl(method = "cv"),
  fitFinal = FALSE
)
knnFit1$finalModel
# NULL