|
1 |
| -# Helpers for quantile regression models |
2 |
| - |
3 |
| -check_quantile_level <- function(x, object, call) { |
4 |
| - if (object$mode != "quantile regression") { |
5 |
| - return(invisible(TRUE)) |
6 |
| - } else { |
7 |
| - if (is.null(x)) { |
8 |
| - cli::cli_abort("In {.fn check_mode}, at least one value of |
9 |
| - {.arg quantile_level} must be specified for quantile regression models.") |
10 |
| - } |
11 |
| - } |
12 |
| - if (any(is.na(x))) { |
13 |
| - cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.", |
14 |
| - call = call) |
15 |
| - } |
16 |
| - x <- sort(unique(x)) |
17 |
| - check_vector_probability(x, arg = "quantile_level", call = call) |
18 |
| - x |
19 |
| -} |
20 |
| - |
21 |
| - |
22 |
| -# ------------------------------------------------------------------------- |
23 |
| -# A column vector of quantiles with an attribute |
24 |
| - |
25 |
| -#' @importFrom vctrs vec_ptype_abbr |
26 |
| -#' @export |
27 |
| -vctrs::vec_ptype_abbr |
28 |
| - |
29 |
| -#' @importFrom vctrs vec_ptype_full |
30 |
| -#' @export |
31 |
| -vctrs::vec_ptype_full |
32 |
| - |
33 |
| - |
34 |
| -#' @export |
35 |
| -vec_ptype_abbr.quantile_pred <- function(x, ...) { |
36 |
| - n_lvls <- length(attr(x, "quantile_levels")) |
37 |
| - cli::format_inline("qtl{?s}({n_lvls})") |
38 |
| -} |
39 |
| - |
40 |
| -#' @export |
41 |
| -vec_ptype_full.quantile_pred <- function(x, ...) "quantiles" |
42 |
| - |
43 |
| -new_quantile_pred <- function(values = list(), quantile_levels = double()) { |
44 |
| - quantile_levels <- vctrs::vec_cast(quantile_levels, double()) |
45 |
| - vctrs::new_vctr( |
46 |
| - values, quantile_levels = quantile_levels, class = "quantile_pred" |
47 |
| - ) |
48 |
| -} |
49 |
| - |
50 |
| -#' Create a vector containing sets of quantiles |
51 |
| -#' |
52 |
| -#' [quantile_pred()] is a special vector class used to efficiently store |
53 |
| -#' predictions from a quantile regression model. It requires the same quantile |
54 |
| -#' levels for each row being predicted. |
55 |
| -#' |
56 |
| -#' @param values A matrix of values. Each column should correspond to one of |
57 |
| -#' the quantile levels. |
58 |
| -#' @param quantile_levels A vector of probabilities corresponding to `values`. |
59 |
| -#' @param x An object produced by [quantile_pred()]. |
60 |
| -#' @param .rows,.name_repair,rownames Arguments not used but required by the |
61 |
| -#' original S3 method. |
62 |
| -#' @param ... Not currently used. |
63 |
| -#' |
64 |
| -#' @export |
65 |
| -#' @return |
66 |
| -#' * [quantile_pred()] returns a vector of values associated with the |
67 |
| -#' quantile levels. |
68 |
| -#' * [extract_quantile_levels()] returns a numeric vector of levels. |
69 |
| -#' * [as_tibble()] returns a tibble with rows `".pred_quantile"`, |
70 |
| -#' `".quantile_levels"`, and `".row"`. |
71 |
| -#' * [as.matrix()] returns an unnamed matrix with rows as sames, columns as |
72 |
| -#' quantile levels, and entries are predictions. |
73 |
| -#' @examples |
74 |
| -#' .pred_quantile <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) |
75 |
| -#' |
76 |
| -#' unclass(.pred_quantile) |
| 1 | +#' Reformat quantile predictions |
77 | 2 | #'
|
78 |
| -#' # Access the underlying information |
79 |
| -#' extract_quantile_levels(.pred_quantile) |
80 |
| -#' |
81 |
| -#' # Matrix format |
82 |
| -#' as.matrix(.pred_quantile) |
83 |
| -#' |
84 |
| -#' # Tidy format |
85 |
| -#' tibble::as_tibble(.pred_quantile) |
86 |
| -quantile_pred <- function(values, quantile_levels = double()) { |
87 |
| - check_quantile_pred_inputs(values, quantile_levels) |
88 |
| - |
89 |
| - quantile_levels <- vctrs::vec_cast(quantile_levels, double()) |
90 |
| - num_lvls <- length(quantile_levels) |
91 |
| - |
92 |
| - if (ncol(values) != num_lvls) { |
93 |
| - cli::cli_abort( |
94 |
| - "The number of columns in {.arg values} must be equal to the length of |
95 |
| - {.arg quantile_levels}." |
96 |
| - ) |
97 |
| - } |
98 |
| - rownames(values) <- NULL |
99 |
| - colnames(values) <- NULL |
100 |
| - values <- lapply(vctrs::vec_chop(values), drop) |
101 |
| - new_quantile_pred(values, quantile_levels) |
102 |
| -} |
103 |
| - |
104 |
| -check_quantile_pred_inputs <- function(values, levels, call = caller_env()) { |
105 |
| - if (any(is.na(levels))) { |
106 |
| - cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.", |
107 |
| - call = call) |
108 |
| - } |
109 |
| - |
110 |
| - if (!is.matrix(values)) { |
111 |
| - cli::cli_abort( |
112 |
| - "{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.", |
113 |
| - call = call |
114 |
| - ) |
115 |
| - } |
116 |
| - check_vector_probability(levels, arg = "quantile_levels", call = call) |
117 |
| - |
118 |
| - if (is.unsorted(levels)) { |
119 |
| - cli::cli_abort( |
120 |
| - "{.arg quantile_levels} must be sorted in increasing order.", |
121 |
| - call = call |
122 |
| - ) |
123 |
| - } |
124 |
| - invisible(NULL) |
125 |
| -} |
126 |
| - |
127 |
| -#' @export |
128 |
| -format.quantile_pred <- function(x, ...) { |
129 |
| - quantile_levels <- attr(x, "quantile_levels") |
130 |
| - if (length(quantile_levels) == 1L) { |
131 |
| - x <- unlist(x) |
132 |
| - out <- round(x, 3L) |
133 |
| - out[is.na(x)] <- NA_real_ |
134 |
| - } else { |
135 |
| - rng <- sapply(x, range, na.rm = TRUE) |
136 |
| - out <- paste0("[", round(rng[1, ], 3L), ", ", round(rng[2, ], 3L), "]") |
137 |
| - out[is.na(rng[1, ]) & is.na(rng[2, ])] <- NA_character_ |
138 |
| - m <- median(x) |
139 |
| - out <- paste0("[", round(m, 3L), "]") |
140 |
| - } |
141 |
| - out |
142 |
| -} |
143 |
| - |
144 |
| -#' @importFrom vctrs obj_print_footer |
145 |
| -#' @export |
146 |
| -vctrs::obj_print_footer |
147 |
| - |
148 |
| -#' @export |
149 |
| -obj_print_footer.quantile_pred <- function(x, digits = 3, ...) { |
150 |
| - lvls <- attr(x, "quantile_levels") |
151 |
| - cat("# Quantile levels: ", format(lvls, digits = digits), "\n", sep = " ") |
152 |
| -} |
153 |
| - |
154 |
| -check_vector_probability <- function(x, ..., |
155 |
| - allow_na = FALSE, |
156 |
| - allow_null = FALSE, |
157 |
| - arg = caller_arg(x), |
158 |
| - call = caller_env()) { |
159 |
| - for (d in x) { |
160 |
| - check_number_decimal( |
161 |
| - d, min = 0, max = 1, |
162 |
| - arg = arg, call = call, |
163 |
| - allow_na = allow_na, |
164 |
| - allow_null = allow_null, |
165 |
| - allow_infinite = FALSE |
166 |
| - ) |
167 |
| - } |
168 |
| -} |
169 |
| - |
| 3 | +#' @param x A matrix of predictions with rows as samples and columns as quantile |
| 4 | +#' levels. |
| 5 | +#' @param object A parsnip `model_fit` object from a quantile regression model. |
| 6 | +#' @keywords internal |
170 | 7 | #' @export
|
171 |
| -median.quantile_pred <- function(x, ...) { |
172 |
| - lvls <- attr(x, "quantile_levels") |
173 |
| - loc_median <- (abs(lvls - 0.5) < sqrt(.Machine$double.eps)) |
174 |
| - if (any(loc_median)) { |
175 |
| - return(map_dbl(x, ~ .x[min(which(loc_median))])) |
176 |
| - } |
177 |
| - if (length(lvls) < 2 || min(lvls) > 0.5 || max(lvls) < 0.5) { |
178 |
| - return(rep(NA, vctrs::vec_size(x))) |
179 |
| - } |
180 |
| - map_dbl(x, ~ stats::approx(lvls, .x, xout = 0.5)$y) |
181 |
| -} |
182 |
| - |
183 |
| -restructure_rq_pred <- function(x, object) { |
| 8 | +matrix_to_quantile_pred <- function(x, object) { |
184 | 9 | if (!is.matrix(x)) {
|
185 | 10 | x <- as.matrix(x)
|
186 | 11 | }
|
187 | 12 | rownames(x) <- NULL
|
188 | 13 | n_pred_quantiles <- ncol(x)
|
189 |
| - quantile_level <- object$spec$quantile_level |
190 |
| - |
191 |
| - tibble::new_tibble(x = list(.pred_quantile = quantile_pred(x, quantile_level))) |
192 |
| -} |
| 14 | + quantile_levels <- object$spec$quantile_levels |
193 | 15 |
|
194 |
| -#' @export |
195 |
| -#' @rdname quantile_pred |
196 |
| -extract_quantile_levels <- function(x) { |
197 |
| - if (!inherits(x, "quantile_pred")) { |
198 |
| - cli::cli_abort("{.arg x} should have class {.val quantile_pred}.") |
199 |
| - } |
200 |
| - attr(x, "quantile_levels") |
201 |
| -} |
202 |
| - |
203 |
| -#' @export |
204 |
| -#' @rdname quantile_pred |
205 |
| -as_tibble.quantile_pred <- |
206 |
| - function (x, ..., .rows = NULL, .name_repair = "minimal", rownames = NULL) { |
207 |
| - lvls <- attr(x, "quantile_levels") |
208 |
| - n_samp <- length(x) |
209 |
| - n_quant <- length(lvls) |
210 |
| - tibble::tibble( |
211 |
| - .pred_quantile = unlist(x), |
212 |
| - .quantile_levels = rep(lvls, n_samp), |
213 |
| - .row = rep(1:n_samp, each = n_quant) |
214 |
| - ) |
215 |
| - } |
216 |
| - |
217 |
| -#' @export |
218 |
| -#' @rdname quantile_pred |
219 |
| -as.matrix.quantile_pred <- function(x, ...) { |
220 |
| - num_samp <- length(x) |
221 |
| - matrix(unlist(x), nrow = num_samp) |
| 16 | + tibble::new_tibble(x = list(.pred_quantile = hardhat::quantile_pred(x, quantile_levels))) |
222 | 17 | }
|
0 commit comments