Skip to content

Commit 539e5a7

Browse files
author
‘topepo’
committed
quantile -> quantile_levels for #1203
1 parent bef131b commit 539e5a7

File tree

6 files changed

+28
-10
lines changed

6 files changed

+28
-10
lines changed

NEWS.md

+3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).
1313

14+
## Breaking Change
15+
16+
* For quantile prediction, the `predict()` argument has been changed from `quantile` to `quantile_levels` for consistency. This does not affect models with mode `"quantile regression"`.
1417

1518
# parsnip 1.2.1
1619

R/predict.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())
344344

345345
# ----------------------------------------------------------------------------
346346

347-
other_args <- c("interval", "level", "std_error", "quantile",
347+
other_args <- c("interval", "level", "std_error", "quantile_levels",
348348
"time", "eval_time", "increasing")
349349
is_pred_arg <- names(the_dots) %in% other_args
350350
if (any(!is_pred_arg)) {

R/predict_quantile.R

+16-5
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
#' @keywords internal
22
#' @rdname other_predict
3-
#' @param quantile A vector of numbers between 0 and 1 for the quantile being
4-
#' predicted.
3+
#' @param quantile_levels A vector of values between zero and one.
54
#' @inheritParams predict.model_fit
65
#' @method predict_quantile model_fit
76
#' @export predict_quantile.model_fit
87
#' @export
98
predict_quantile.model_fit <- function(object,
109
new_data,
11-
quantile = (1:9)/10,
10+
quantile_levels = NULL,
1211
interval = "none",
1312
level = 0.95,
1413
...) {
@@ -20,15 +19,27 @@ predict_quantile.model_fit <- function(object,
2019
return(NULL)
2120
}
2221

22+
if (object$spec$mode != "quantile regression") {
23+
if (is.null(quantile_levels)) {
24+
quantile_levels <- (1:9)/10
25+
}
26+
hardhat::check_quantile_levels(quantile_levels)
27+
# Pass some extra arguments to be used in post-processor
28+
object$quantile_levels <- quantile_levels
29+
} else {
30+
if (!is.null(quantile_levels)) {
31+
cli::cli_abort("{.arg quantile_levels} are specified by {.fn set_mode}
32+
when the mode is {.val quantile regression}.")
33+
}
34+
}
35+
2336
new_data <- prepare_data(object, new_data)
2437

2538
# preprocess data
2639
if (!is.null(object$spec$method$pred$quantile$pre)) {
2740
new_data <- object$spec$method$pred$quantile$pre(new_data, object)
2841
}
2942

30-
# Pass some extra arguments to be used in post-processor
31-
object$spec$method$pred$quantile$args$p <- quantile
3243
pred_call <- make_pred_call(object$spec$method$pred$quantile)
3344

3445
res <- eval_tidy(pred_call)

man/other_predict.Rd

+2-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/set_args.Rd

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-linear_reg_quantreg.R

+5
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ test_that('linear quantile regression via quantreg - multiple quantiles', {
8383
expect_named(ten_quant_df, c(".pred_quantile", ".quantile_levels", ".row"))
8484
expect_true(nrow(ten_quant_df) == nrow(sac_test) * 10)
8585

86+
expect_snapshot(
87+
ten_quant_pred <- predict(ten_quant, new_data = sac_test),
88+
error = TRUE
89+
)
90+
8691
###
8792

8893
ten_quant_one_row <- predict(ten_quant, new_data = sac_test[1,])

0 commit comments

Comments
 (0)