Run a cross-validation
cv() executes a cross-validation procedure.
For each fold (specified via the arguments nfold or folds), the original model is re-fitted using the complement
of the fold as training data, and the fold itself is used as the test set.
When several models are cross-validated together, identical folds are used for all of them.
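The base R snippet below sketches what this means for a single model; it is only a conceptual illustration of complete k-fold cross-validation, not the package's internal implementation.

# Conceptual sketch (plain base R, not the package's internals): each fold
# serves once as the test set while the model is re-fitted on the rest.
set.seed(1)
k <- 5
fold_id <- sample(rep(seq_len(k), length.out = nrow(mtcars)))
fold_rmse <- vapply(seq_len(k), function(i) {
  fit <- lm(mpg ~ cyl, data = mtcars[fold_id != i, ])
  pred <- predict(fit, newdata = mtcars[fold_id == i, ])
  sqrt(mean((mtcars$mpg[fold_id == i] - pred)^2))
}, numeric(1))
mean(fold_rmse)  # cross-validated rmse of the simple linear model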
Usage
cv(x, ...)

# S3 method for model
cv(
  x,
  nfold = getOption("cv_nfold"),
  folds = NULL,
  ...,
  metric = NULL,
  iter = getOption("cv_iter"),
  param = TRUE,
  keep_fits = FALSE,
  verbose = getOption("cv_verbose")
)

# S3 method for multimodel
cv(
  x,
  nfold = getOption("cv_nfold"),
  folds = NULL,
  metric = NULL,
  iter = getOption("cv_iter"),
  param = TRUE,
  keep_fits = FALSE,
  verbose = getOption("cv_verbose"),
  ...
)

# S3 method for default
cv(
  x,
  nfold = getOption("cv_nfold"),
  folds = NULL,
  ...,
  metric = NULL,
  iter = getOption("cv_iter"),
  param = TRUE,
  keep_fits = FALSE
)

# S3 method for cv
print(
  x,
  what = c("class", "formula", "weights"),
  show_metric = TRUE,
  abbreviate = TRUE,
  n = getOption("print_max_model"),
  width = getOption("width"),
  param = TRUE,
  ...
)

Arguments
- x
A model, a multimodel, or a fitted model (see section “Methods”).

- ...
These arguments are passed internally to methods of cv_simple(), a currently undocumented generic that runs the cross-validation on a single model.

- nfold, folds
Passed to make_folds; see the usage sketch after this list.

- metric
A metric (see metrics). metric = NULL selects the default metric, see default_metric.

- iter
A preference criterion, or a list of several criteria. Only relevant for iteratively fitted models (see ifm), ignored otherwise.

- param
Logical: Include the parameter table in the output? See ?multimodel.

- keep_fits
Logical: Keep the cross-validation's model fits?

- verbose
Logical: Print information on execution progress to the console?

- what
Which elements of the multimodel should be printed? See print.model.

- show_metric
Logical: Print the cross-validated models' metric?

- abbreviate
Logical. If TRUE (the default), long formulas and calls are printed in abbreviated form, such that they usually fit on 4 or fewer output lines; otherwise they are printed in full, however long they are.

- n
Integer: Model details are printed for the first n models in print.cv().

- width
Integer: Width of the printed output.
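A brief usage sketch combining several of these arguments (only arguments documented above are used; output is omitted, and the commented make_folds() call is a placeholder, see ?make_folds):

# Hedged usage sketch: combines arguments documented above.
m <- model(fm_knn(Sepal.Length ~ ., iris))
cv10 <- cv(m, nfold = 10, keep_fits = TRUE, verbose = FALSE)
print(cv10, n = 1, show_metric = FALSE)  # details of the first model only
# Alternatively, supply prebuilt folds (see ?make_folds for the exact call):
# cv(m, folds = make_folds(...))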
Value
The output from cv() is a list of class “cv” having the following elements:
- multimodel: a multimodel;
- folds: the folds, as defined in nfold or folds (see make_folds);
- fits: if keep_fits = FALSE (the default): NULL; if keep_fits = TRUE: the list of model fits resulting from the cross-validation (see extract_fits);
- metric: a list: the default evaluation metrics, not necessarily the same for all models;
- predictions: a list of matrices of dimension \(n \times k\), where \(n\) is the number of observations in the model data and \(k\) is the number of folds; each of these list entries corresponds to a model;
- performance: a list of performance tables (see cv_performance), which are saved only for certain model classes; often NULL;
- timing: the execution time of the cross-validation;
- extras: a list of extra results from the cross-validation, which are saved only for certain model classes; often NULL. If the model x is an iteratively fitted model (ifm), extras contains the cross-validated model's evaluation log and information on the preferred iterations.
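Since the return value is a plain list of class “cv”, its components can be inspected with standard list subsetting. A hedged sketch based on the element names listed above:

# Hedged sketch: accessing documented components of a "cv" object.
cv_obj <- cv(model(lm(mpg ~ cyl, mtcars)), nfold = 5, keep_fits = TRUE)
names(cv_obj)                 # multimodel, folds, fits, metric, predictions, ...
dim(cv_obj$predictions[[1]])  # n observations x k folds, for the first model
cv_obj$fits                   # model fits, kept because keep_fits = TRUE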
Details
The same cross-validation groups (folds) are used for all models.
Each model in x is processed separately with the function cv_simple(),
a generic function for internal use.
Besides the standard method cv_simple.model(), there are currently dedicated methods of cv_simple()
for models generated with fm_xgb() and fm_glmnet().
Methods
- cv.multimodel() is the core method.
- cv.model(x, ...) corresponds to x %>% multimodel %>% cv(...).
- The default method essentially executes x %>% model %>% cv(...) and thus expects a fitted model as its x.
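A hedged illustration of this correspondence: per the statements above, the three calls below should run the same cross-validation of one fitted model.

# Hedged sketch of the method correspondence described above.
fit <- lm(mpg ~ cyl, mtcars)
cv(fit)                     # default method: a fitted model
cv(model(fit))              # cv.model(): wraps the model first
cv(multimodel(model(fit)))  # cv.multimodel(), the core method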
Examples
mm <- multimodel(model(fm_knn(Sepal.Length ~ ., iris)), k = 1:5)
cv(mm)
#> --- A “cv” object containing 5 validated models ---
#>
#> Validation procedure: Complete k-fold Cross-Validation
#> Number of obs in data: 150
#> Number of test sets: 10
#> Size of test sets: 15
#> Size of training sets: 135
#>
#> Models:
#>
#> ‘model1’:
#> model class: fm_knn
#> formula: Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width +
#> Species
#> metric: rmse
#>
#> ‘model2’:
#> model class: fm_knn
#> formula: Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width +
#> Species
#> metric: rmse
#>
#> ‘model3’:
#> model class: fm_knn
#> formula: Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width +
#> Species
#> metric: rmse
#>
#> and 2 models more, labelled:
#> ‘model4’, ‘model5’
#>
#>
#> Parameter table:
#> k
#> model1 1
#> model2 2
#> model3 3
#> ... 2 rows omitted (nrow=5)
mm_cars <- c(simpleLinear = model(lm(mpg ~ cyl, mtcars)),
             linear = model(lm(mpg ~ ., mtcars)),
             if (require(ranger)) model(ranger(mpg ~ ., mtcars), label = "forest"))
#> Loading required package: ranger
mm_cars
#> --- A “multimodel” object containing 3 models ---
#>
#> ‘simpleLinear’:
#> model class: lm
#> formula: mpg ~ cyl
#> data: data.frame [32 x 11],
#> input as: ‘data = mtcars’
#> call: lm(formula = mpg ~ cyl, data = data)
#>
#> ‘linear’:
#> model class: lm
#> formula: mpg ~ cyl + disp + hp + drat + wt + qsec + vs +
#> am + gear + carb
#> data: data.frame [32 x 11],
#> input as: ‘data = mtcars’
#> call: lm(formula = mpg ~ ., data = data)
#>
#> ‘forest’:
#> model class: ranger
#> formula: mpg ~ cyl + disp + hp + drat + wt + qsec + vs +
#> am + gear + carb
#> data: data.frame [32 x 11],
#> input as: ‘data = mtcars’
#> call: ranger(formula = mpg ~ ., data = data)
cv_cars <- cv(mm_cars, nfold = 5)
cv_cars
#> --- A “cv” object containing 3 validated models ---
#>
#> Validation procedure: Complete k-fold Cross-Validation
#> Number of obs in data: 32
#> Number of test sets: 5
#> Size of test sets: ~6
#> Size of training sets: ~26
#>
#> Models:
#>
#> ‘simpleLinear’:
#> model class: lm
#> formula: mpg ~ cyl
#> metric: rmse
#>
#> ‘linear’:
#> model class: lm
#> formula: mpg ~ cyl + disp + hp + drat + wt + qsec + vs +
#> am + gear + carb
#> metric: rmse
#>
#> ‘forest’:
#> model class: ranger
#> formula: mpg ~ cyl + disp + hp + drat + wt + qsec + vs +
#> am + gear + carb
#> metric: rmse
cv_performance(cv_cars)
#> --- Performance table ---
#> Metric: rmse
#> train_rmse test_rmse time_cv
#> simpleLinear 3.0495 3.5248 0.005
#> linear 2.0169 3.7303 0.009
#> forest 1.3057 2.4861 0.046
# Non-default metric:
cv_performance(cv_cars, metric = "medae")
#> --- Performance table ---
#> Metric: medae
#> train_medae test_medae time_cv
#> simpleLinear 1.7604 2.3055 0.005
#> linear 1.4328 2.3144 0.009
#> forest 0.9158 2.0434 0.046
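The evaluation metric can also be fixed when the cross-validation is run, via the metric argument of cv(); the sketch below mirrors the cv_performance() call above and assumes that cv() accepts the same character shorthand (output omitted).

# Hedged sketch: setting the metric at cross-validation time.
cv_cars_medae <- cv(mm_cars, nfold = 5, metric = "medae")
cv_performance(cv_cars_medae)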