Skip to content

Commit 556c732

Browse files
Support All keras activations functions (#1244)
* allow all keras activations * update news * add tests --------- Co-authored-by: ‘topepo’ <[email protected]>
1 parent e9354e7 commit 556c732

File tree

5 files changed

+70
-5
lines changed

5 files changed

+70
-5
lines changed

NAMESPACE

+1
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ export(glm_grouped)
253253
export(has_multi_predict)
254254
export(importance_weights)
255255
export(is_varying)
256+
export(keras_activations)
256257
export(keras_mlp)
257258
export(keras_predict_classes)
258259
export(knit_engine_docs)

NEWS.md

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).
1717

18+
* `mlp()` with `keras` engine now work for all activation functions currently supported by `keras` (#1127).
19+
1820
## Other Changes
1921

2022
* Transitioned package errors and warnings to use cli (#1147 and #1148 by @shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160, #1161, #1081).

R/mlp.R

+20-2
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,13 @@ keras_mlp <-
192192
seeds = sample.int(10^5, size = 3),
193193
...) {
194194

195-
act_funs <- c("linear", "softmax", "relu", "elu", "tanh")
196-
rlang::arg_match(activation, act_funs)
195+
allowed_keras_activation <- keras_activations()
196+
good_activation <- activation %in% allowed_keras_activation
197+
if (!all(good_activation)) {
198+
cli::cli_abort(
199+
"{.arg activation} should be one of: {allowed_activation}."
200+
)
201+
}
197202

198203
if (penalty > 0 & dropout > 0) {
199204
cli::cli_abort("Please use either dropout or weight decay.", call = NULL)
@@ -344,6 +349,19 @@ mlp_num_weights <- function(p, hidden_units, classes) {
344349
((p + 1) * hidden_units) + ((hidden_units+1) * classes)
345350
}
346351

352+
allowed_keras_activation <-
353+
c("elu", "exponential", "gelu", "hard_sigmoid", "linear", "relu", "selu",
354+
"sigmoid", "softmax", "softplus", "softsign", "swish", "tanh")
355+
356+
#' Activation functions for neural networks in keras
357+
#'
358+
#' @keywords internal
359+
#' @return A character vector of values.
360+
#' @export
361+
keras_activations <- function() {
362+
allowed_keras_activation
363+
}
364+
347365
## -----------------------------------------------------------------------------
348366

349367
#' @importFrom purrr map

man/keras_activations.Rd

+15
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-mlp_keras.R

+32-3
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,10 @@ car_basic <- mlp(mode = "regression", epochs = 10) %>%
149149

150150
bad_keras_reg <-
151151
mlp(mode = "regression") %>%
152-
set_engine("keras", min.node.size = -10)
152+
set_engine("keras", min.node.size = -10, verbose = 0)
153153

154154
# ------------------------------------------------------------------------------
155155

156-
157156
test_that('keras execution, regression', {
158157
skip_on_cran()
159158
skip_if_not_installed("keras")
@@ -211,7 +210,6 @@ test_that('keras regression prediction', {
211210
keras::backend()$clear_session()
212211
})
213212

214-
215213
# ------------------------------------------------------------------------------
216214

217215
test_that('multivariate nnet formula', {
@@ -247,3 +245,34 @@ test_that('multivariate nnet formula', {
247245

248246
keras::backend()$clear_session()
249247
})
248+
249+
# ------------------------------------------------------------------------------
250+
251+
test_that('all keras activation functions', {
252+
skip_on_cran()
253+
skip_if_not_installed("keras")
254+
skip_if_not_installed("modeldata")
255+
skip_if(!is_tf_ok())
256+
257+
act <- parsnip:::keras_activations()
258+
259+
test_act <- function(fn) {
260+
set.seed(1)
261+
try(
262+
mlp(mode = "classification", hidden_units = 2, penalty = 0.01, epochs = 2,
263+
activation = !!fn) %>%
264+
set_engine("keras", verbose = 0) %>%
265+
parsnip::fit(Class ~ A + B, data = modeldata::two_class_dat),
266+
silent = TRUE)
267+
268+
}
269+
test_act_sshhh <- purrr::quietly(test_act)
270+
271+
for (i in act) {
272+
keras::backend()$clear_session()
273+
act_res <- test_act_sshhh(i)
274+
expect_s3_class(act_res$result, "model_fit")
275+
keras::backend()$clear_session()
276+
}
277+
278+
})

0 commit comments

Comments
 (0)