Skip to content

Commit 53263e9

Browse files
topepohfrick
andauthored
Divert linear regressions with poisson family to poisson_reg() (#1219)
* Changes for #956 * update news * point to parsnip PR instead of tune issue * move poisson checks to check_args * the snapshots should be created without glmnet installed * GHA for tests has glmnet installed (◔_◔) --------- Co-authored-by: Hannah Frick <[email protected]>
1 parent 8855842 commit 53263e9

File tree

4 files changed

+89
-0
lines changed

4 files changed

+89
-0
lines changed

NEWS.md

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
* Aligned `null_model()` with other model types; the model type now has an engine argument that defaults to `"parsnip"` and is checked with the same machinery that checks other model types in the package (#1083).
2525

26+
* If linear regression is requested with a Poisson family, an error will occur and refer the user to `poisson_reg()` (#1219).
27+
2628
* The deprecated function `rpart_train()` was removed after its deprecation period (#1044).
2729

2830
## Bug Fixes

R/linear_reg.R

+22
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ translate.linear_reg <- function(x, engine = x$engine, ...) {
7373
# evaluated value for the parameter.
7474
x$args$penalty <- rlang::eval_tidy(x$args$penalty)
7575
}
76+
7677
x
7778
}
7879

@@ -113,5 +114,26 @@ check_args.linear_reg <- function(object, call = rlang::caller_env()) {
113114
check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture")
114115
check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty")
115116

117+
# ------------------------------------------------------------------------------
118+
# We want to avoid folks passing in a poisson family instead of using
119+
# poisson_reg(). It's hard to detect this.
120+
121+
is_fam <- names(object$eng_args) == "family"
122+
if (any(is_fam)) {
123+
eng_args <- rlang::eval_tidy(object$eng_args[[which(is_fam)]])
124+
if (is.function(eng_args)) {
125+
eng_args <- try(eng_args(), silent = TRUE)
126+
}
127+
if (inherits(eng_args, "family")) {
128+
eng_args <- eng_args$family
129+
}
130+
if (eng_args == "poisson") {
131+
cli::cli_abort(
132+
"A Poisson family was requested for {.fn linear_reg}. Please use
133+
{.fn poisson_reg} and the engines in the {.pkg poissonreg} package.",
134+
call = rlang::call2("linear_reg"))
135+
}
136+
}
137+
116138
invisible(object)
117139
}

tests/testthat/_snaps/linear_reg.md

+36
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,39 @@
139139
Error in `fit()`:
140140
! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1.
141141

142+
# prevent using a Poisson family
143+
144+
Code
145+
linear_reg(penalty = 1) %>% set_engine("glmnet", family = poisson) %>% fit(mpg ~
146+
., data = mtcars)
147+
Condition
148+
Error in `fit()`:
149+
! Please install the glmnet package to use this engine.
150+
151+
---
152+
153+
Code
154+
linear_reg(penalty = 1) %>% set_engine("glmnet", family = stats::poisson) %>%
155+
fit(mpg ~ ., data = mtcars)
156+
Condition
157+
Error in `fit()`:
158+
! Please install the glmnet package to use this engine.
159+
160+
---
161+
162+
Code
163+
linear_reg(penalty = 1) %>% set_engine("glmnet", family = stats::poisson()) %>%
164+
fit(mpg ~ ., data = mtcars)
165+
Condition
166+
Error in `fit()`:
167+
! Please install the glmnet package to use this engine.
168+
169+
---
170+
171+
Code
172+
linear_reg(penalty = 1) %>% set_engine("glmnet", family = "poisson") %>% fit(
173+
mpg ~ ., data = mtcars)
174+
Condition
175+
Error in `fit()`:
176+
! Please install the glmnet package to use this engine.
177+

tests/testthat/test-linear_reg.R

+29
Original file line numberDiff line numberDiff line change
@@ -358,3 +358,32 @@ test_that("check_args() works", {
358358
}
359359
)
360360
})
361+
362+
363+
test_that("prevent using a Poisson family", {
364+
skip_if(rlang::is_installed("glmnet"))
365+
expect_snapshot(
366+
linear_reg(penalty = 1) %>%
367+
set_engine("glmnet", family = poisson) %>%
368+
fit(mpg ~ ., data = mtcars),
369+
error = TRUE
370+
)
371+
expect_snapshot(
372+
linear_reg(penalty = 1) %>%
373+
set_engine("glmnet", family = stats::poisson) %>%
374+
fit(mpg ~ ., data = mtcars),
375+
error = TRUE
376+
)
377+
expect_snapshot(
378+
linear_reg(penalty = 1) %>%
379+
set_engine("glmnet", family = stats::poisson()) %>%
380+
fit(mpg ~ ., data = mtcars),
381+
error = TRUE
382+
)
383+
expect_snapshot(
384+
linear_reg(penalty = 1) %>%
385+
set_engine("glmnet", family = "poisson") %>%
386+
fit(mpg ~ ., data = mtcars),
387+
error = TRUE
388+
)
389+
})

0 commit comments

Comments
 (0)