diff --git a/DESCRIPTION b/DESCRIPTION index 04e16da1..7cb6bd80 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: tune Title: Tidy Tuning Tools -Version: 1.3.0.9000 +Version: 1.3.0.9001 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-2402-136X")), diff --git a/NAMESPACE b/NAMESPACE index 7b1e69c8..4f58e16e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -71,6 +71,7 @@ S3method(outcome_names,recipe) S3method(outcome_names,terms) S3method(outcome_names,tune_results) S3method(outcome_names,workflow) +S3method(outcome_names,workflow_variables) S3method(parameters,model_spec) S3method(parameters,recipe) S3method(parameters,workflow) diff --git a/NEWS.md b/NEWS.md index 2cfe791c..b8367510 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,8 @@ * Post-processing: new `schedule_grid()` for scheduling a grid including post-processing (#988). +* Added a new method for `outcome_names()` for workflows that use `add_variables()` (#993). + # tune 1.3.0 * The package will now warn when parallel processing has been enabled with foreach but not with future. See [`?parallelism`](https://tune.tidymodels.org/dev/reference/parallelism.html) to learn more about transitioning your code to future (#878, #866). The next version of tune will move to a pure future implementation. diff --git a/R/outcome-names.R b/R/outcome-names.R index ad1e761e..2a9abe60 100644 --- a/R/outcome-names.R +++ b/R/outcome-names.R @@ -1,7 +1,9 @@ #' Determine names of the outcome data in a workflow #' #' @param x An object. -#' @param ... Not used. +#' @param ... Further arguments passed to or from other methods (such as `data`). +#' @param data The training set data (if needed). +#' @param call The call to be displayed in warnings or errors. #' @return A character string of variable names #' @keywords internal #' @examples @@ -39,20 +41,39 @@ outcome_names.recipe <- function(x, ...) { #' @export #' @rdname outcome_names -outcome_names.workflow <- function(x, ...) { - if (!is.null(x$fit$fit)) { +outcome_names.workflow <- function(x, ..., call = caller_env()) { + if (!is.null(x$pre$mold)) { y_vals <- extract_mold(x)$outcomes res <- colnames(y_vals) } else { preprocessor <- extract_preprocessor(x) - res <- outcome_names(preprocessor) + res <- outcome_names(preprocessor, ..., call = call) } res } #' @export #' @rdname outcome_names -outcome_names.tune_results <- function(x, ...) { +outcome_names.workflow_variables <- function( + x, + data = NULL, + ..., + call = caller_env() +) { + if (is.null(data)) { + cli::cli_abort( + "To determine the outcome names when {.fn add_variables} is used, please + pass the training set data to the {.arg data} argument.", + call = call + ) + } + res <- rlang::eval_tidy(x$outcomes, data, env = call) + res +} + +#' @export +#' @rdname outcome_names +outcome_names.tune_results <- function(x, ..., call = caller_env()) { att <- attributes(x) if (any(names(att) == "outcomes")) { res <- att$outcomes diff --git a/man/outcome_names.Rd b/man/outcome_names.Rd index 1066d9aa..2639096b 100644 --- a/man/outcome_names.Rd +++ b/man/outcome_names.Rd @@ -6,6 +6,7 @@ \alias{outcome_names.formula} \alias{outcome_names.recipe} \alias{outcome_names.workflow} +\alias{outcome_names.workflow_variables} \alias{outcome_names.tune_results} \title{Determine names of the outcome data in a workflow} \usage{ @@ -17,14 +18,18 @@ outcome_names(x, ...) \method{outcome_names}{recipe}(x, ...) -\method{outcome_names}{workflow}(x, ...) +\method{outcome_names}{workflow}(x, ..., call = caller_env()) -\method{outcome_names}{tune_results}(x, ...) +\method{outcome_names}{workflow_variables}(x, data = NULL, ..., call = caller_env()) + +\method{outcome_names}{tune_results}(x, ..., call = caller_env()) } \arguments{ \item{x}{An object.} -\item{...}{Not used.} +\item{...}{Further arguments passed to or from other methods (such as \code{data}).} + +\item{data}{The training set data (if needed).} } \value{ A character string of variable names diff --git a/tests/testthat/_snaps/outcome-names.md b/tests/testthat/_snaps/outcome-names.md new file mode 100644 index 00000000..f9eba55d --- /dev/null +++ b/tests/testthat/_snaps/outcome-names.md @@ -0,0 +1,8 @@ +# workflows + variables + + Code + outcome_names(wflow_1) + Condition + Error: + ! To determine the outcome names when `add_variables()` is used, please pass the training set data to the `data` argument. + diff --git a/tests/testthat/test-outcome-names.R b/tests/testthat/test-outcome-names.R index d1c66f36..029009c2 100644 --- a/tests/testthat/test-outcome-names.R +++ b/tests/testthat/test-outcome-names.R @@ -95,6 +95,19 @@ test_that("workflows + formulas", { expect_equal(outcome_names(parsnip::fit(wflow_2, mtcars)), c("mpg", "wt")) }) +## ----------------------------------------------------------------------------- + +test_that("workflows + variables", { + lm_mod <- parsnip::linear_reg() %>% parsnip::set_engine("lm") + wflow <- workflow() %>% add_model(lm_mod) + + wflow_1 <- wflow %>% add_variables(outcomes = "mpg", predictors = c(wt)) + fit_1 <- fit(wflow_1, mtcars) + + expect_snapshot(outcome_names(wflow_1), error = TRUE) + expect_equal(outcome_names(wflow_1, mtcars), "mpg") + expect_equal(outcome_names(fit_1, mtcars), "mpg") +}) ## -----------------------------------------------------------------------------