Skip to content

Commit e9354e7

Browse files
don't turn sparse matrix into dense matrix for glmnet prediction (#1210)
* don't turn sparse matrix into dense matrix for glmnet prediction * remember to do subsetting * don't add random newlines * test glmnet predict doesn't remove sparseness * regenerate snapshots to resolve merge conflicts --------- Co-authored-by: ‘topepo’ <[email protected]>
1 parent 53263e9 commit e9354e7

File tree

4 files changed

+47
-1
lines changed

4 files changed

+47
-1
lines changed

R/glmnet-engines.R

+10
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,16 @@ predict_raw._glmnetfit <- predict_raw_glmnet
138138
unname(x[, 1])
139139
}
140140

141+
organize_glmnet_pre_pred <- function(x, object) {
142+
x <- x[, rownames(object$fit$beta), drop = FALSE]
143+
if (is_sparse_matrix(x)) {
144+
return(x)
145+
}
146+
147+
as.matrix(x)
148+
}
149+
150+
141151
organize_glmnet_class <- function(x, object) {
142152
prob_to_class_2(x[, 1], object)
143153
}

R/linear_reg_data.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ set_pred(
250250
args =
251251
list(
252252
object = expr(object$fit),
253-
newx = expr(as.matrix(new_data[, rownames(object$fit$beta), drop = FALSE])),
253+
newx = expr(organize_glmnet_pre_pred(new_data, object)),
254254
type = "response",
255255
s = expr(object$spec$args$penalty)
256256
)

tests/testthat/_snaps/sparsevctrs.md

+8
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,14 @@
127127
Error in `maybe_sparse_matrix()`:
128128
! no sparse vectors detected
129129

130+
# we don't run as.matrix() on sparse matrix for glmnet pred #1210
131+
132+
Code
133+
predict(lm_fit, hotel_data)
134+
Condition
135+
Error in `predict.elnet()`:
136+
! data is sparse
137+
130138
# fit() errors if sparse matrix has no colnames
131139

132140
Code

tests/testthat/test-sparsevctrs.R

+28
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,34 @@ test_that("maybe_sparse_matrix() is used correctly", {
314314
)
315315
})
316316

317+
test_that("we don't run as.matrix() on sparse matrix for glmnet pred #1210", {
318+
skip_if_not_installed("glmnet")
319+
320+
local_mocked_bindings(
321+
predict.elnet = function(object, newx, ...) {
322+
if (is_sparse_matrix(newx)) {
323+
stop("data is sparse")
324+
} else {
325+
stop("data isn't sparse (should not happen)")
326+
}
327+
},
328+
.package = "glmnet"
329+
)
330+
331+
hotel_data <- sparse_hotel_rates()
332+
333+
spec <- linear_reg(penalty = 0) %>%
334+
set_mode("regression") %>%
335+
set_engine("glmnet")
336+
337+
lm_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
338+
339+
expect_snapshot(
340+
error = TRUE,
341+
predict(lm_fit, hotel_data)
342+
)
343+
})
344+
317345
test_that("fit() errors if sparse matrix has no colnames", {
318346
hotel_data <- sparse_hotel_rates()
319347
colnames(hotel_data) <- NULL

0 commit comments

Comments
 (0)