Skip to contents

Plot ROC/PR curves for feature selection, or some summary thereof, across experimental replicates.

Usage

plot_feature_selection_curve(
  fit_results,
  eval_results = NULL,
  evaluator_name = NULL,
  vary_params = NULL,
  curve = c("ROC", "PR"),
  show = c("line", "ribbon"),
  ...
)

Arguments

fit_results

A tibble, as returned by the fit method.

eval_results

A list of result tibbles, as returned by the evaluate method.

evaluator_name

Name of the Evaluator containing the results to plot. Set to NULL to compute the evaluation summary results from scratch, or if the evaluation summary results have not yet been computed.

vary_params

A vector of names of the parameters that are varied across the Experiment.

curve

Either "ROC" or "PR" indicating whether to plot the ROC or Precision-Recall curve.

show

Character vector whose elements are each one of "boxplot", "point", "line", "bar", "errorbar", "ribbon", indicating which plot layer(s) to construct.

...

Additional arguments to pass to plot_eval_summary(). This includes arguments for plotting and for passing into summarize_feature_selection_curve().

Value

If interactive = TRUE, returns a plotly object if plot_by is NULL and a list of plotly objects if plot_by is not NULL. If interactive = FALSE, returns a ggplot object if plot_by is NULL and a list of ggplot objects if plot_by is not NULL.

Examples

# Build a toy fit_results tibble: two replicates for each of two DGPs,
# all fit with a single method
fit_results <- tibble::tibble(
  .rep = rep(seq_len(2), times = 2),
  .dgp_name = rep(c("DGP1", "DGP2"), each = 2),
  .method_name = "Method",
  # one nested tibble of per-feature results for each of the four rows
  feature_info = lapply(
    seq_len(4),
    function(i) {
      tibble::tibble(
        feature = c("featureA", "featureB", "featureC"),  # feature names
        true_support = c(TRUE, FALSE, TRUE),  # true feature support
        est_support = c(TRUE, FALSE, FALSE),  # estimated feature support
        # estimated feature importance scores (last two drawn at random)
        est_importance = c(10, runif(n = 2, min = -2, max = 2))
      )
    }
  )
)

# Pre-compute the example eval_results: one curve summary per curve type,
# stored in a list whose elements are named "ROC" and "PR"
eval_results <- lapply(
  c(ROC = "ROC", PR = "PR"),
  function(curve_type) {
    summarize_feature_selection_curve(
      fit_results,
      curve = curve_type,
      nested_data = "feature_info",
      truth_col = "true_support",
      imp_col = "est_importance"
    )
  }
)

# Option 1: draw the summary ROC/PR plots from the pre-computed
# evaluation results
roc_plt <- plot_feature_selection_curve(
  fit_results = fit_results,
  eval_results = eval_results,
  evaluator_name = "ROC",
  curve = "ROC",
  show = c("line", "ribbon")
)
pr_plt <- plot_feature_selection_curve(
  fit_results = fit_results,
  eval_results = eval_results,
  evaluator_name = "PR",
  curve = "PR",
  show = c("line", "ribbon")
)

# Option 2: draw the same plots without pre-computing evaluation results
roc_plt <- plot_feature_selection_curve(
  fit_results = fit_results,
  curve = "ROC",
  show = c("line", "ribbon"),
  nested_data = "feature_info",
  truth_col = "true_support",
  imp_col = "est_importance"
)
pr_plt <- plot_feature_selection_curve(
  fit_results = fit_results,
  curve = "PR",
  show = c("line", "ribbon"),
  nested_data = "feature_info",
  truth_col = "true_support",
  imp_col = "est_importance"
)

# The plot can be customized (see plot_eval_summary() for possible arguments)
roc_plt <- plot_feature_selection_curve(
  fit_results = fit_results,
  eval_results = eval_results,
  evaluator_name = "ROC",
  curve = "ROC",
  show = c("line", "ribbon"),
  plot_by = ".dgp_name"
)