Skip to content

Commit a153df0

Browse files
author
‘topepo’
committed
enable quantile prediction
1 parent abcd97d commit a153df0

File tree

6 files changed

+55
-46
lines changed

6 files changed

+55
-46
lines changed

R/quantiles.R renamed to R/aaa_quantiles.R

+19-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ check_quantile_level <- function(x, object, call) {
99
{.arg quantile_level} must be specified for quantile regression models.")
1010
}
1111
}
12+
x <- sort(unique(x))
1213
# TODO we need better vectorization here, otherwise we get things like:
1314
# "Error during wrapup: i In index: 2." in the traceback.
1415
res <-
@@ -17,7 +18,24 @@ check_quantile_level <- function(x, object, call) {
1718
arg = "quantile_level", call = call,
1819
allow_infinite = FALSE)
1920
)
20-
return(invisible(TRUE))
21+
x
2122
}
2223

24+
# Assumes the columns have the same order as quantile_level
25+
restructure_rq_pred <- function(x, object) {
26+
n <- nrow(x)
27+
p <- ncol(x)
28+
# TODO check p = length(quantile_level)
29+
# check p = 1 case
30+
quantile_level <- object$spec$quantile_level
31+
res <-
32+
tibble::tibble(
33+
.pred_quantile = as.vector(x),
34+
.quantile_level = rep(quantile_level, each = n),
35+
.row = rep(1:n, p))
36+
res <- vctrs::vec_split(x = res[,1:2], by = res[, ".row"])
37+
res <- vctrs::vec_cbind(res$key, tibble::new_tibble(list(.pred_quantile = res$val)))
38+
res$.row <- NULL
39+
res
40+
}
2341

R/arguments.R

+2-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ set_mode.model_spec <- function(object, mode, quantile_level = NULL, ...) {
109109

110110
object$mode <- mode
111111
object$user_specified_mode <- TRUE
112-
check_quantile_level(quantile_level, object, call = caller_env(0))
112+
quantile_level <-
113+
check_quantile_level(quantile_level, object, call = caller_env(0))
113114
object$quantile_level <- quantile_level
114115
object
115116
}

R/fit.R

+8
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ fit.model_spec <-
169169
eval_env$formula <- formula
170170
eval_env$weights <- wts
171171

172+
if ( !is.null(object$quantile_level) ) {
173+
eval_env$quantile_level <- object$quantile_level
174+
}
175+
172176
fit_interface <-
173177
check_interface(eval_env$formula, eval_env$data, cl, object)
174178

@@ -282,6 +286,10 @@ fit_xy.model_spec <-
282286
eval_env$y_var <- y_var
283287
eval_env$weights <- weights_to_numeric(case_weights, object)
284288

289+
if ( !is.null(object$quantile_level) ) {
290+
eval_env$quantile_level <- object$quantile_level
291+
}
292+
285293
# TODO case weights: pass in eval_env not individual elements
286294
fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object)
287295

R/linear_reg_data.R

+4-11
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ set_fit(
596596
interface = "formula",
597597
protect = c("formula", "data", "weights"),
598598
func = c(pkg = "quantreg", fun = "rq"),
599-
defaults = list()
599+
defaults = list(tau = expr(quantile_level))
600600
)
601601
)
602602

@@ -635,22 +635,15 @@ set_pred(
635635
model = "linear_reg",
636636
eng = "quantreg",
637637
mode = "quantile regression",
638-
type = "conf_int",
638+
type = "quantile",
639639
value = list(
640640
pre = NULL,
641-
post = function(results, object) {
642-
tibble::as_tibble(results) %>%
643-
dplyr::select(-fit) %>%
644-
setNames(c(".pred_lower", ".pred_upper"))
645-
},
641+
post = restructure_rq_pred,
646642
func = c(fun = "predict"),
647643
args =
648644
list(
649645
object = expr(object$fit),
650-
newdata = expr(new_data),
651-
interval = "confidence",
652-
level = expr(level)
646+
newdata = expr(new_data)
653647
)
654648
)
655649
)
656-

R/predict_quantile.R

+21-25
Original file line numberDiff line numberDiff line change
@@ -6,40 +6,36 @@
66
#' @method predict_quantile model_fit
77
#' @export predict_quantile.model_fit
88
#' @export
9-
predict_quantile.model_fit <- function(object,
10-
new_data,
11-
quantile = (1:9)/10,
12-
interval = "none",
13-
level = 0.95,
14-
...) {
9+
predict_quantile.model_fit <- function(object, new_data, ...) {
1510

16-
check_spec_pred_type(object, "quantile")
11+
check_spec_pred_type(object, "quantile")
1712

18-
if (inherits(object$fit, "try-error")) {
19-
rlang::warn("Model fit failed; cannot make predictions.")
20-
return(NULL)
21-
}
22-
23-
new_data <- prepare_data(object, new_data)
13+
if (inherits(object$fit, "try-error")) {
14+
cli::cli_warn("Model fit failed; cannot make predictions.")
15+
return(NULL)
16+
}
2417

25-
# preprocess data
26-
if (!is.null(object$spec$method$pred$quantile$pre))
27-
new_data <- object$spec$method$pred$quantile$pre(new_data, object)
18+
new_data <- prepare_data(object, new_data)
2819

29-
# Pass some extra arguments to be used in post-processor
30-
object$spec$method$pred$quantile$args$p <- quantile
31-
pred_call <- make_pred_call(object$spec$method$pred$quantile)
20+
# preprocess data
21+
if (!is.null(object$spec$method$pred$quantile$pre)) {
22+
new_data <- object$spec$method$pred$quantile$pre(new_data, object)
23+
}
3224

33-
res <- eval_tidy(pred_call)
25+
# Pass some extra arguments to be used in post-processor
26+
object$spec$method$pred$quantile$args$quantile_level <- object$quantile_level
27+
pred_call <- make_pred_call(object$spec$method$pred$quantile)
3428

35-
# post-process the predictions
36-
if(!is.null(object$spec$method$pred$quantile$post)) {
37-
res <- object$spec$method$pred$quantile$post(res, object)
38-
}
29+
res <- eval_tidy(pred_call)
3930

40-
res
31+
# post-process the predictions
32+
if(!is.null(object$spec$method$pred$quantile$post)) {
33+
res <- object$spec$method$pred$quantile$post(res, object)
4134
}
4235

36+
res
37+
}
38+
4339
# @export
4440
# @keywords internal
4541
# @rdname other_predict

man/other_predict.Rd

+1-8
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)