Skip to content

Commit bef131b

Browse files
topepo‘topepo’hfrick
authored
quantile regression updates for new hardhat model (#1207)
* bump hardhat version * remove parts now in hardhat * update for new hardhat version * quantile_levels (plural now) * news update * typo * rename helper function * run CI on PRs from branches * forgotten remote * actions for edited PRs * plural * expand branch list * export function for censored to use * updated snapshot * remake snapshot * Revert "remake snapshot" This reverts commit 954e326. * updated snapshot * Update R/arguments.R Co-authored-by: Hannah Frick <[email protected]> * typo * changes from reviewer feedback --------- Co-authored-by: ‘topepo’ <‘[email protected]’> Co-authored-by: Hannah Frick <[email protected]>
1 parent 3bdb471 commit bef131b

22 files changed

+96
-560
lines changed

.github/workflows/R-CMD-check.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ on:
88
push:
99
branches: [main, master]
1010
pull_request:
11-
branches: [main, master]
11+
branches: [main, master, quantile-mode]
1212
workflow_dispatch:
1313

1414
name: R-CMD-check.yaml

.github/workflows/test-coverage.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ on:
44
push:
55
branches: [main, master]
66
pull_request:
7-
branches: [main, master]
7+
branches: [main, master, quantile-mode]
88

99
name: test-coverage.yaml
1010

DESCRIPTION

+3-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ Imports:
2525
ggplot2,
2626
globals,
2727
glue,
28-
hardhat (>= 1.4.0),
28+
hardhat (>= 1.4.0.9002),
2929
lifecycle,
3030
magrittr,
3131
pillar,
@@ -68,6 +68,8 @@ Suggests:
6868
VignetteBuilder:
6969
knitr
7070
ByteCompile: true
71+
Remotes:
72+
tidymodels/hardhat
7173
Config/Needs/website: C50, dbarts, earth, glmnet, keras, kernlab, kknn,
7274
LiblineaR, mgcv, nnet, parsnip, randomForest, ranger, rpart, rstanarm,
7375
tidymodels/tidymodels, tidyverse/tidytemplate, rstudio/reticulate,

NAMESPACE

+1-15
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
S3method(.censoring_weights_graf,default)
44
S3method(.censoring_weights_graf,model_fit)
5-
S3method(as.matrix,quantile_pred)
6-
S3method(as_tibble,quantile_pred)
75
S3method(augment,model_fit)
86
S3method(autoplot,glmnet)
97
S3method(autoplot,model_fit)
@@ -38,12 +36,10 @@ S3method(extract_spec_parsnip,model_fit)
3836
S3method(fit,model_spec)
3937
S3method(fit_xy,gen_additive_mod)
4038
S3method(fit_xy,model_spec)
41-
S3method(format,quantile_pred)
4239
S3method(glance,model_fit)
4340
S3method(has_multi_predict,default)
4441
S3method(has_multi_predict,model_fit)
4542
S3method(has_multi_predict,workflow)
46-
S3method(median,quantile_pred)
4743
S3method(multi_predict,"_C5.0")
4844
S3method(multi_predict,"_earth")
4945
S3method(multi_predict,"_elnet")
@@ -58,7 +54,6 @@ S3method(multi_predict_args,default)
5854
S3method(multi_predict_args,model_fit)
5955
S3method(multi_predict_args,workflow)
6056
S3method(nullmodel,default)
61-
S3method(obj_print_footer,quantile_pred)
6257
S3method(predict,"_elnet")
6358
S3method(predict,"_glmnetfit")
6459
S3method(predict,"_lognet")
@@ -177,8 +172,6 @@ S3method(update,svm_rbf)
177172
S3method(varying_args,model_spec)
178173
S3method(varying_args,recipe)
179174
S3method(varying_args,step)
180-
S3method(vec_ptype_abbr,quantile_pred)
181-
S3method(vec_ptype_full,quantile_pred)
182175
export("%>%")
183176
export(.censoring_weights_graf)
184177
export(.check_glmnet_penalty_fit)
@@ -233,7 +226,6 @@ export(extract_fit_engine)
233226
export(extract_fit_time)
234227
export(extract_parameter_dials)
235228
export(extract_parameter_set_dials)
236-
export(extract_quantile_levels)
237229
export(extract_spec_parsnip)
238230
export(find_engine_files)
239231
export(fit)
@@ -272,6 +264,7 @@ export(make_classes)
272264
export(make_engine_list)
273265
export(make_seealso_list)
274266
export(mars)
267+
export(matrix_to_quantile_pred)
275268
export(max_mtry_formula)
276269
export(maybe_data_frame)
277270
export(maybe_matrix)
@@ -288,7 +281,6 @@ export(new_model_spec)
288281
export(null_model)
289282
export(null_value)
290283
export(nullmodel)
291-
export(obj_print_footer)
292284
export(parsnip_addin)
293285
export(pls)
294286
export(poisson_reg)
@@ -316,7 +308,6 @@ export(prepare_data)
316308
export(print_model_spec)
317309
export(prompt_missing_implementation)
318310
export(proportional_hazards)
319-
export(quantile_pred)
320311
export(rand_forest)
321312
export(repair_call)
322313
export(req_pkgs)
@@ -360,8 +351,6 @@ export(update_model_info_file)
360351
export(update_spec)
361352
export(varying)
362353
export(varying_args)
363-
export(vec_ptype_abbr)
364-
export(vec_ptype_full)
365354
export(xgb_predict)
366355
export(xgb_train)
367356
import(rlang)
@@ -439,8 +428,5 @@ importFrom(utils,globalVariables)
439428
importFrom(utils,head)
440429
importFrom(utils,methods)
441430
importFrom(utils,stack)
442-
importFrom(vctrs,obj_print_footer)
443-
importFrom(vctrs,vec_ptype_abbr)
444-
importFrom(vctrs,vec_ptype_full)
445431
importFrom(vctrs,vec_size)
446432
importFrom(vctrs,vec_unique)

NEWS.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# parsnip (development version)
22

33
* A new model mode (`"quantile regression"`) was added. Including:
4-
* A function to create a new vector class called `quantile_pred()` was added (#1191).
54
* A `linear_reg()` engine for `"quantreg"`.
5+
* Predictions are encoded via a custom vector type. See [hardhat::quantile_pred()].
6+
* Predicted quantile levels are designated when the new mode is specified. See `?set_mode`.
67

78
* `fit_xy()` currently raises an error for `gen_additive_mod()` model specifications as the default engine (`"mgcv"`) specifies smoothing terms in model formulas. However, some engines specify smooths via additional arguments, in which case the restriction on `fit_xy()` is excessive. parsnip will now only raise an error when fitting a `gen_additive_mod()` with `fit_xy()` when using the `"mgcv"` engine (#775).
89

R/aaa_quantiles.R

+8-213
Original file line numberDiff line numberDiff line change
@@ -1,222 +1,17 @@
1-
# Helpers for quantile regression models
2-
3-
check_quantile_level <- function(x, object, call) {
4-
if (object$mode != "quantile regression") {
5-
return(invisible(TRUE))
6-
} else {
7-
if (is.null(x)) {
8-
cli::cli_abort("In {.fn check_mode}, at least one value of
9-
{.arg quantile_level} must be specified for quantile regression models.")
10-
}
11-
}
12-
if (any(is.na(x))) {
13-
cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.",
14-
call = call)
15-
}
16-
x <- sort(unique(x))
17-
check_vector_probability(x, arg = "quantile_level", call = call)
18-
x
19-
}
20-
21-
22-
# -------------------------------------------------------------------------
23-
# A column vector of quantiles with an attribute
24-
25-
#' @importFrom vctrs vec_ptype_abbr
26-
#' @export
27-
vctrs::vec_ptype_abbr
28-
29-
#' @importFrom vctrs vec_ptype_full
30-
#' @export
31-
vctrs::vec_ptype_full
32-
33-
34-
#' @export
35-
vec_ptype_abbr.quantile_pred <- function(x, ...) {
36-
n_lvls <- length(attr(x, "quantile_levels"))
37-
cli::format_inline("qtl{?s}({n_lvls})")
38-
}
39-
40-
#' @export
41-
vec_ptype_full.quantile_pred <- function(x, ...) "quantiles"
42-
43-
new_quantile_pred <- function(values = list(), quantile_levels = double()) {
44-
quantile_levels <- vctrs::vec_cast(quantile_levels, double())
45-
vctrs::new_vctr(
46-
values, quantile_levels = quantile_levels, class = "quantile_pred"
47-
)
48-
}
49-
50-
#' Create a vector containing sets of quantiles
51-
#'
52-
#' [quantile_pred()] is a special vector class used to efficiently store
53-
#' predictions from a quantile regression model. It requires the same quantile
54-
#' levels for each row being predicted.
55-
#'
56-
#' @param values A matrix of values. Each column should correspond to one of
57-
#' the quantile levels.
58-
#' @param quantile_levels A vector of probabilities corresponding to `values`.
59-
#' @param x An object produced by [quantile_pred()].
60-
#' @param .rows,.name_repair,rownames Arguments not used but required by the
61-
#' original S3 method.
62-
#' @param ... Not currently used.
63-
#'
64-
#' @export
65-
#' @return
66-
#' * [quantile_pred()] returns a vector of values associated with the
67-
#' quantile levels.
68-
#' * [extract_quantile_levels()] returns a numeric vector of levels.
69-
#' * [as_tibble()] returns a tibble with rows `".pred_quantile"`,
70-
#' `".quantile_levels"`, and `".row"`.
71-
#' * [as.matrix()] returns an unnamed matrix with rows as sames, columns as
72-
#' quantile levels, and entries are predictions.
73-
#' @examples
74-
#' .pred_quantile <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8))
75-
#'
76-
#' unclass(.pred_quantile)
1+
#' Reformat quantile predictions
772
#'
78-
#' # Access the underlying information
79-
#' extract_quantile_levels(.pred_quantile)
80-
#'
81-
#' # Matrix format
82-
#' as.matrix(.pred_quantile)
83-
#'
84-
#' # Tidy format
85-
#' tibble::as_tibble(.pred_quantile)
86-
quantile_pred <- function(values, quantile_levels = double()) {
87-
check_quantile_pred_inputs(values, quantile_levels)
88-
89-
quantile_levels <- vctrs::vec_cast(quantile_levels, double())
90-
num_lvls <- length(quantile_levels)
91-
92-
if (ncol(values) != num_lvls) {
93-
cli::cli_abort(
94-
"The number of columns in {.arg values} must be equal to the length of
95-
{.arg quantile_levels}."
96-
)
97-
}
98-
rownames(values) <- NULL
99-
colnames(values) <- NULL
100-
values <- lapply(vctrs::vec_chop(values), drop)
101-
new_quantile_pred(values, quantile_levels)
102-
}
103-
104-
check_quantile_pred_inputs <- function(values, levels, call = caller_env()) {
105-
if (any(is.na(levels))) {
106-
cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.",
107-
call = call)
108-
}
109-
110-
if (!is.matrix(values)) {
111-
cli::cli_abort(
112-
"{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.",
113-
call = call
114-
)
115-
}
116-
check_vector_probability(levels, arg = "quantile_levels", call = call)
117-
118-
if (is.unsorted(levels)) {
119-
cli::cli_abort(
120-
"{.arg quantile_levels} must be sorted in increasing order.",
121-
call = call
122-
)
123-
}
124-
invisible(NULL)
125-
}
126-
127-
#' @export
128-
format.quantile_pred <- function(x, ...) {
129-
quantile_levels <- attr(x, "quantile_levels")
130-
if (length(quantile_levels) == 1L) {
131-
x <- unlist(x)
132-
out <- round(x, 3L)
133-
out[is.na(x)] <- NA_real_
134-
} else {
135-
rng <- sapply(x, range, na.rm = TRUE)
136-
out <- paste0("[", round(rng[1, ], 3L), ", ", round(rng[2, ], 3L), "]")
137-
out[is.na(rng[1, ]) & is.na(rng[2, ])] <- NA_character_
138-
m <- median(x)
139-
out <- paste0("[", round(m, 3L), "]")
140-
}
141-
out
142-
}
143-
144-
#' @importFrom vctrs obj_print_footer
145-
#' @export
146-
vctrs::obj_print_footer
147-
148-
#' @export
149-
obj_print_footer.quantile_pred <- function(x, digits = 3, ...) {
150-
lvls <- attr(x, "quantile_levels")
151-
cat("# Quantile levels: ", format(lvls, digits = digits), "\n", sep = " ")
152-
}
153-
154-
check_vector_probability <- function(x, ...,
155-
allow_na = FALSE,
156-
allow_null = FALSE,
157-
arg = caller_arg(x),
158-
call = caller_env()) {
159-
for (d in x) {
160-
check_number_decimal(
161-
d, min = 0, max = 1,
162-
arg = arg, call = call,
163-
allow_na = allow_na,
164-
allow_null = allow_null,
165-
allow_infinite = FALSE
166-
)
167-
}
168-
}
169-
3+
#' @param x A matrix of predictions with rows as samples and columns as quantile
4+
#' levels.
5+
#' @param object A parsnip `model_fit` object from a quantile regression model.
6+
#' @keywords internal
1707
#' @export
171-
median.quantile_pred <- function(x, ...) {
172-
lvls <- attr(x, "quantile_levels")
173-
loc_median <- (abs(lvls - 0.5) < sqrt(.Machine$double.eps))
174-
if (any(loc_median)) {
175-
return(map_dbl(x, ~ .x[min(which(loc_median))]))
176-
}
177-
if (length(lvls) < 2 || min(lvls) > 0.5 || max(lvls) < 0.5) {
178-
return(rep(NA, vctrs::vec_size(x)))
179-
}
180-
map_dbl(x, ~ stats::approx(lvls, .x, xout = 0.5)$y)
181-
}
182-
183-
restructure_rq_pred <- function(x, object) {
8+
matrix_to_quantile_pred <- function(x, object) {
1849
if (!is.matrix(x)) {
18510
x <- as.matrix(x)
18611
}
18712
rownames(x) <- NULL
18813
n_pred_quantiles <- ncol(x)
189-
quantile_level <- object$spec$quantile_level
190-
191-
tibble::new_tibble(x = list(.pred_quantile = quantile_pred(x, quantile_level)))
192-
}
14+
quantile_levels <- object$spec$quantile_levels
19315

194-
#' @export
195-
#' @rdname quantile_pred
196-
extract_quantile_levels <- function(x) {
197-
if (!inherits(x, "quantile_pred")) {
198-
cli::cli_abort("{.arg x} should have class {.val quantile_pred}.")
199-
}
200-
attr(x, "quantile_levels")
201-
}
202-
203-
#' @export
204-
#' @rdname quantile_pred
205-
as_tibble.quantile_pred <-
206-
function (x, ..., .rows = NULL, .name_repair = "minimal", rownames = NULL) {
207-
lvls <- attr(x, "quantile_levels")
208-
n_samp <- length(x)
209-
n_quant <- length(lvls)
210-
tibble::tibble(
211-
.pred_quantile = unlist(x),
212-
.quantile_levels = rep(lvls, n_samp),
213-
.row = rep(1:n_samp, each = n_quant)
214-
)
215-
}
216-
217-
#' @export
218-
#' @rdname quantile_pred
219-
as.matrix.quantile_pred <- function(x, ...) {
220-
num_samp <- length(x)
221-
matrix(unlist(x), nrow = num_samp)
16+
tibble::new_tibble(x = list(.pred_quantile = hardhat::quantile_pred(x, quantile_levels)))
22217
}

0 commit comments

Comments
 (0)