`fit_caret()`, `fit_tidymodels()`, and `fit_h2o()` are wrappers for fitting a model with the caret, tidymodels, and h2o backends, respectively. They share a common set of input arguments, which makes it easy to switch between modeling packages (see `fit_models()`).

fit_caret(
  Xtrain,
  ytrain,
  model_name,
  model_options = list(),
  cv_options = list(),
  train_options = list()
)

fit_h2o(
  Xtrain,
  ytrain,
  model_name,
  model_options = list(),
  cv_options = list(),
  train_options = list()
)

fit_tidymodels(
  Xtrain,
  ytrain,
  model_name,
  model_options = list(),
  cv_options = list(),
  train_options = list()
)

Arguments

Xtrain

Training data matrix or data frame.

ytrain

Training response vector.

model_name

Name of the model to fit. See the caret, h2o, or tidymodels documentation for the available models.

model_options

List of named arguments to pass to the model-fitting function. See details below.

cv_options

List of cross-validation options used to tune hyperparameters. Possible options are `nfolds` (default 10), `foldids`, and `metric`. `nfolds` is the number of cross-validation folds. `foldids` is a list with one element per fold, where each element is an integer vector of the row indices used for training in that fold. `metric` is a string specifying the metric used to select the best hyperparameters. See details below.

train_options

List of additional training control options. See details below.
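As a sketch of the structure described for `cv_options$foldids`, the base R code below assigns rows to balanced random folds and stores, for each fold, the indices of the rows used for training. The sample size `n`, the number of folds, and the metric string "RMSE" (a caret metric name; valid strings depend on the backend) are illustrative assumptions.

```r
# Build `cv_options$foldids`: one element per fold, each holding the
# integer row indices used for *training* in that fold.
set.seed(123)
n <- 100       # number of rows in Xtrain (assumed for illustration)
nfolds <- 5

# Randomly assign each row to a fold, keeping fold sizes balanced
fold_assign <- sample(rep(seq_len(nfolds), length.out = n))

# Training rows for fold k are all rows not held out in fold k
foldids <- lapply(seq_len(nfolds), function(k) which(fold_assign != k))

cv_options <- list(nfolds = nfolds, foldids = foldids, metric = "RMSE")
```

Each row is held out in exactly one fold, so the complements of the `foldids` elements partition `seq_len(n)`.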

Value

The trained model fit, whose class depends on the backend.

`fit_caret()` returns a trained model fit of class `train` (see the output of `train()`).

`fit_tidymodels()` returns a trained model fit of class `workflow` (see `fit-workflow` in the workflows package) if hyperparameter tuning is not performed. If hyperparameter tuning is performed, `fit_tidymodels()` instead returns the cross-validation results of class `tune_results` (see the output of `tune::tune_grid()`), with an additional attribute "best_fit" that holds the finalized `workflow` fit using the best hyperparameters.

`fit_h2o()` returns a trained h2o model fit (see the output of `h2o.[model_name]`) if hyperparameter tuning is not performed. If hyperparameter tuning is performed, `fit_h2o()` instead returns the cross-validation grid results of class `H2OGrid` (see the output of `h2o.grid()`), with an additional attribute "best_fit" that holds the finalized h2o model fit using the best hyperparameters.
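Because tuned and untuned fits are returned in different shapes, downstream code may want a single way to get the finalized model. The helper below, `get_final_fit()`, is a hypothetical sketch and not part of the package: it returns the "best_fit" attribute when present and otherwise returns the fit unchanged.

```r
# Hypothetical helper (not part of the package): return the finalized model
# from the output of fit_caret(), fit_tidymodels(), or fit_h2o().
# Tuned outputs of fit_tidymodels() and fit_h2o() carry the finalized fit
# in the "best_fit" attribute; otherwise the output is itself the model.
get_final_fit <- function(fit) {
  best <- attr(fit, "best_fit", exact = TRUE)
  if (is.null(best)) fit else best
}
```

For example, `final_model <- get_final_fit(fit)` works the same way whether or not `model_options$.tune_params` was supplied. For `fit_caret()`, the `train` object is returned as-is.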

Details

To tune hyperparameters, add an element named `.tune_params` to the `model_options` list, containing a named list of the hyperparameter values to tune over. `fit_caret()` and `fit_tidymodels()` also accept a data frame of hyperparameter combinations. In `fit_caret()`, `model_options$.tune_params` is passed to the `tuneGrid` argument of `train()`. In `fit_tidymodels()`, it is passed to the `grid` argument of `tune::tune_grid()`. In `fit_h2o()`, it is passed to the `hyper_params` argument of `h2o.grid()`.
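A minimal sketch of the two accepted forms of `.tune_params`. The hyperparameter names (`mtry` for a random forest in caret or parsnip; `max_depth` and `ntrees` for h2o) are illustrative assumptions; in practice they must match the tuning parameters of the chosen `model_name`.

```r
# Data-frame form (accepted by fit_caret() and fit_tidymodels()):
# each row is one candidate hyperparameter combination.
tune_df <- data.frame(mtry = c(2, 4, 6))

# Named-list form: each element lists candidate values for one
# hyperparameter. For fit_h2o(), this is passed to h2o.grid(hyper_params = ...),
# which searches over all combinations under its default Cartesian search.
tune_list <- list(max_depth = c(5, 10), ntrees = c(50, 100))

model_options_df <- list(.tune_params = tune_df)
model_options_list <- list(.tune_params = tune_list)
```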

For `fit_caret()`, `train_options` should be a list of named arguments to pass to `trainControl()`, and `model_options` should be a list of named arguments to pass to `train()`. See the `metric` argument of `train()` for valid values of `cv_options$metric`.
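A sketch of how the three option lists might be set up for `fit_caret()`, assuming a regression task and caret's `"rf"` method (whose tuning parameter is `mtry`). The argument values are illustrative, and the commented call assumes `Xtrain`, `ytrain`, and the package providing `fit_caret()` are available.

```r
model_options <- list(
  preProcess = c("center", "scale"),         # a train() argument
  .tune_params = data.frame(mtry = c(2, 4))  # passed to train(tuneGrid = ...)
)
train_options <- list(savePredictions = "final")  # a trainControl() argument
cv_options <- list(nfolds = 5, metric = "RMSE")   # a train() metric for regression

# fit <- fit_caret(Xtrain, ytrain, model_name = "rf",
#                  model_options = model_options, cv_options = cv_options,
#                  train_options = train_options)
```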

For `fit_tidymodels()`, `train_options` should be a list of named arguments to pass to `tune::tune_grid()`; it is used only when hyperparameter tuning is performed. `model_options` should be a list of named arguments to pass to the parsnip model function given by `model_name` (e.g., `parsnip::rand_forest()`). If additional engine arguments need to be set, `model_options$engine` can be a list of these arguments (e.g., `model_options$engine = list(engine = "ranger", importance = "impurity")`). See the `metric` argument of `tune::select_best()` for valid values of `cv_options$metric`.
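A sketch of `model_options` for `fit_tidymodels()` with a random forest, combining a fixed model argument, engine settings, and a tuning grid. The value `model_name = "rand_forest"` is inferred from the `parsnip::rand_forest()` example above and is an assumption, as is the regression metric `"rmse"`. The commented call assumes `Xtrain`, `ytrain`, and the package providing `fit_tidymodels()` are available.

```r
model_options <- list(
  trees = 500,                                                   # parsnip::rand_forest() argument
  engine = list(engine = "ranger", importance = "impurity"),     # engine settings
  .tune_params = expand.grid(mtry = c(2, 4), min_n = c(5, 10))   # grid for tune::tune_grid()
)
cv_options <- list(nfolds = 5, metric = "rmse")  # metric name for tune::select_best()

# fit <- fit_tidymodels(Xtrain, ytrain, model_name = "rand_forest",
#                       model_options = model_options, cv_options = cv_options)
```

The grid has four rows, one per combination of `mtry` and `min_n`.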

For `fit_h2o()`, `train_options` should be a list of named arguments to pass to `h2o.grid()`; it is used only when hyperparameter tuning is performed. `model_options` should be a list of named arguments to pass to `h2o.grid()` if hyperparameter tuning is performed, or otherwise to the h2o model function given by `model_name` (e.g., `h2o.randomForest()`). See the `sort_by` argument of `h2o.getGrid()` for valid values of `cv_options$metric`. The `decreasing` argument of `h2o.getGrid()` is set to `FALSE` unless `cv_options$metric` is one of "auc", "accuracy", "precision", "recall", or "f1", so larger values are ranked as better only for these metrics.
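A sketch of `fit_h2o()` options for a random forest, plus a hypothetical helper, `sort_decreasing()`, that mirrors the documented rule for the `decreasing` argument of `h2o.getGrid()`. The helper is not the package's code and assumes exact, lowercase metric names; `model_name = "randomForest"` is inferred from the `h2o.[model_name]` pattern and is likewise an assumption.

```r
model_options <- list(
  ntrees = 100,                                            # fixed algorithm argument
  .tune_params = list(max_depth = c(5, 10), mtries = c(2, 4))  # h2o.grid(hyper_params = ...)
)
cv_options <- list(nfolds = 5, metric = "rmse")

# Hypothetical helper: TRUE only for metrics where larger values are better
sort_decreasing <- function(metric) {
  metric %in% c("auc", "accuracy", "precision", "recall", "f1")
}
sort_decreasing(cv_options$metric)  # FALSE: lower RMSE is better
```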

See also

Other fit_models_family: fit_models()