Skip to content

Commit 9cfcb7f

Browse files
authored
fixes for #1236 (#1237)
1 parent 082351b commit 9cfcb7f

File tree

4 files changed

+27
-3
lines changed

4 files changed

+27
-3
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333

3434
* Fixed bug where some models fit using `fit_xy()` couldn't predict (#1166).
3535

36+
* `tunable()` now references a dials object for the `mixture` parameter (#1236)
37+
3638
## Breaking Change
3739

3840
* For quantile prediction, the `quantile` argument to `predict()` has been deprecate in facor of `quantile_levels`. This does not affect models with mode `"quantile regression"`.

R/tunable.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,9 @@ brulee_mlp_engine_args <-
206206
"largest", list(pkg = "dials", fun = "rate_largest"),
207207
"rate_schedule", list(pkg = "dials", fun = "rate_schedule"),
208208
"step_size", list(pkg = "dials", fun = "rate_step_size"),
209-
"steps", list(pkg = "dials", fun = "rate_steps")
209+
"mixture", list(pkg = "dials", fun = "mixture")
210210
) %>%
211-
dplyr::mutate(,
212-
source = "model_spec",
211+
dplyr::mutate(source = "model_spec",
213212
component = "mlp",
214213
component_id = "engine"
215214
)

parsnip.Rproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
Version: 1.0
2+
ProjectId: 7f6c9ff5-6b9a-4235-8666-12db5ef65d49
23

34
RestoreWorkspace: No
45
SaveWorkspace: No

tests/testthat/test-tunable.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
test_that('brulee has mixture object', {
2+
# for issue 1236
3+
mlp_spec <-
4+
mlp(
5+
hidden_units = tune(),
6+
activation = tune(),
7+
penalty = tune(),
8+
learn_rate = tune(),
9+
epoch = 2000
10+
) %>%
11+
set_mode("regression") %>%
12+
set_engine("brulee",
13+
stop_iter = tune(),
14+
mixture = tune(),
15+
rate_schedule = tune())
16+
17+
brulee_res <- tunable(mlp_spec)
18+
19+
expect_true(
20+
length(brulee_res$call_info[brulee_res$name == "mixture"]) > 0
21+
)
22+
})

0 commit comments

Comments
 (0)