Skip to content

Commit 6168556

Browse files
committed
tests for quantreg
1 parent a153df0 commit 6168556

File tree

2 files changed

+90
-4
lines changed

2 files changed

+90
-4
lines changed

R/aaa_quantiles.R

+6-4
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,18 @@ check_quantile_level <- function(x, object, call) {
2323

2424
# Assumes the columns have the same order as quantile_level
2525
restructure_rq_pred <- function(x, object) {
26+
num_quantiles <- NCOL(x)
27+
if ( num_quantiles == 1L ){
28+
x <- matrix(x, ncol = 1)
29+
}
2630
n <- nrow(x)
27-
p <- ncol(x)
28-
# TODO check p = length(quantile_level)
29-
# check p = 1 case
31+
3032
quantile_level <- object$spec$quantile_level
3133
res <-
3234
tibble::tibble(
3335
.pred_quantile = as.vector(x),
3436
.quantile_level = rep(quantile_level, each = n),
35-
.row = rep(1:n, p))
37+
.row = rep(1:n, num_quantiles))
3638
res <- vctrs::vec_split(x = res[,1:2], by = res[, ".row"])
3739
res <- vctrs::vec_cbind(res$key, tibble::new_tibble(list(.pred_quantile = res$val)))
3840
res$.row <- NULL
+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
test_that('linear quantile regression via quantreg - single quantile', {
2+
skip_if_not_installed("quantreg")
3+
4+
data("Sacramento")
5+
6+
Sacramento_small <-
7+
Sacramento %>%
8+
dplyr::select(price, beds, baths, sqft, latitude, longitude)
9+
10+
sac_train <- Sacramento_small[-(1:5), ]
11+
sac_test <- Sacramento_small[ 1:5 , ]
12+
13+
one_quant <-
14+
linear_reg() %>%
15+
set_engine("quantreg") %>%
16+
set_mode("quantile regression", quantile_level = .5) %>%
17+
fit(price ~ ., data = sac_train)
18+
19+
expect_s3_class(one_quant, c("_rq", "model_fit"))
20+
21+
###
22+
23+
one_quant_pred <- predict(one_quant, new_data = sac_test)
24+
expect_true(nrow(one_quant_pred) == nrow(sac_test))
25+
expect_named(one_quant_pred, ".pred_quantile")
26+
expect_true(is.list(one_quant_pred[[1]]))
27+
expect_s3_class(one_quant_pred$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame"))
28+
expect_named(one_quant_pred$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level"))
29+
expect_true(nrow(one_quant_pred$.pred_quantile[[1]]) == 1L)
30+
31+
###
32+
33+
one_quant_one_row <- predict(one_quant, new_data = sac_test[1,])
34+
expect_true(nrow(one_quant_one_row) == 1L)
35+
expect_named(one_quant_one_row, ".pred_quantile")
36+
expect_true(is.list(one_quant_one_row[[1]]))
37+
expect_s3_class(one_quant_one_row$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame"))
38+
expect_named(one_quant_one_row$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level"))
39+
expect_true(nrow(one_quant_one_row$.pred_quantile[[1]]) == 1L)
40+
})
41+
42+
test_that('linear quantile regression via quantreg - multiple quantiles', {
43+
skip_if_not_installed("quantreg")
44+
45+
data("Sacramento")
46+
47+
Sacramento_small <-
48+
Sacramento %>%
49+
dplyr::select(price, beds, baths, sqft, latitude, longitude)
50+
51+
sac_train <- Sacramento_small[-(1:5), ]
52+
sac_test <- Sacramento_small[ 1:5 , ]
53+
54+
ten_quant <-
55+
linear_reg() %>%
56+
set_engine("quantreg") %>%
57+
set_mode("quantile regression", quantile_level = (0:9)/9) %>%
58+
fit(price ~ ., data = sac_train)
59+
60+
expect_s3_class(ten_quant, c("_rq", "model_fit"))
61+
62+
###
63+
64+
ten_quant_pred <- predict(ten_quant, new_data = sac_test)
65+
expect_true(nrow(ten_quant_pred) == nrow(sac_test))
66+
expect_named(ten_quant_pred, ".pred_quantile")
67+
expect_true(is.list(ten_quant_pred[[1]]))
68+
expect_s3_class(ten_quant_pred$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame"))
69+
expect_named(ten_quant_pred$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level"))
70+
expect_true(nrow(ten_quant_pred$.pred_quantile[[1]]) == 10L)
71+
72+
###
73+
74+
ten_quant_one_row <- predict(ten_quant, new_data = sac_test[1,])
75+
expect_true(nrow(ten_quant_one_row) == 1L)
76+
expect_named(ten_quant_one_row, ".pred_quantile")
77+
expect_true(is.list(ten_quant_one_row[[1]]))
78+
expect_s3_class(ten_quant_one_row$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame"))
79+
expect_named(ten_quant_one_row$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level"))
80+
expect_true(nrow(ten_quant_one_row$.pred_quantile[[1]]) == 10L)
81+
})
82+
83+
84+

0 commit comments

Comments
 (0)