diff --git a/DESCRIPTION b/DESCRIPTION index 8e4ec8b..2982fde 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,22 +1,27 @@ Package: ordered -Title: Wrappers for Ordinal Classification Models +Title: 'parsnip' Engines and Wrappers for Ordinal Classification Models Version: 0.0.0.9000 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-2402-136X")), + person("Jason Cory", "Brunson", , "cornelioid@gmail.com", role = "aut", + comment = c(ORCID = "0000-0003-3126-9494")), person("Posit Software PBC", role = "cph") ) Description: Bindings for ordinal classification models for use with the - 'parsnip' package, such as ordinal random forests by Hornung R. (2020) - and others. + 'parsnip' package, such as the proportional odds logistic regression + implemented in 'MASS' and the ordinal random forests of Hornung (2020) + . License: MIT + file LICENSE Depends: parsnip (>= 1.2.1.9003) Imports: cli, dplyr, - rlang (>= 1.1.4) + rlang (>= 1.1.4), + tibble Suggests: + MASS, ordinalForest, QSARdata, spelling, diff --git a/NEWS.md b/NEWS.md new file mode 100644 index 0000000..fd71664 --- /dev/null +++ b/NEWS.md @@ -0,0 +1,3 @@ +# ordered (development version) + +* Initial CRAN submission. diff --git a/R/ordered-package.R b/R/ordered-package.R index 7e37562..c5763ab 100644 --- a/R/ordered-package.R +++ b/R/ordered-package.R @@ -1,3 +1,45 @@ +#' {ordered}: parsnip Engines for Ordinal Regression Models +#' +#' {ordered} provides engines for ordinal regression models for the {parsnip} +#' package. The models may have cumulative, sequential, or adjacent-category +#' structure, and in future these may be disaggregated into separate model +#' types. A vignette will provide thorough illustrations of {ordered} +#' functionality. See below for examples of fitting ordinal regression models +#' with {ordered}. +#' +#' @examples +#' if (rlang::is_installed("MASS")) { +#' +#' # Weighted sample +#' +#' set.seed(561246) +#' house_sub <- MASS::housing |> +#' dplyr::sample_n(size = 120, replace = TRUE, weight = Freq) |> +#' subset(select = -Freq) +#' train_inds <- sample(120, 80) +#' house_train <- house_sub[train_inds, ] +#' house_test <- house_sub[-train_inds, ] +#' +#' # Cumulative-link proportional-odds probit regression model +#' +#' fit_cpop <- ordinal_reg() |> +#' set_engine("polr") |> +#' set_args(method = "probit") |> +#' fit(Sat ~ Infl + Type + Cont, data = house_train) +#' predict(fit_cpop, house_test, type = "prob") +#' +#' if (rlang::is_installed("ordinalForest")) { +#' +#' # Ordinal forest +#' +#' fit_orf <- rand_forest(mode = "classification") |> +#' set_engine("ordinalForest") |> +#' fit(Sat ~ Infl + Type + Cont, data = house_train) +#' predict(fit_orf, house_test, type = "prob") +#' +#' } +#' } +#' #' @keywords internal "_PACKAGE" diff --git a/R/ordinal_reg_data.R b/R/ordinal_reg_data.R new file mode 100644 index 0000000..5703d87 --- /dev/null +++ b/R/ordinal_reg_data.R @@ -0,0 +1,87 @@ +# These functions define the ordinal regression models. +# They are executed when this package is loaded via `.onLoad()` +# and modify the {parsnip} package's model environment. + +# These functions are tested indirectly when the models are used. +# Since they are added to the parsnip model database on startup execution, +# they can't be test-executed so are excluded from coverage stats. + +# nocov start + +make_ordinal_reg_polr <- function() { + + parsnip::set_model_engine("ordinal_reg", "classification", "polr") + parsnip::set_dependency( + "ordinal_reg", + eng = "polr", + pkg = "ordered", + mode = "classification" + ) + + parsnip::set_fit( + model = "ordinal_reg", + eng = "polr", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "MASS", fun = "polr"), + defaults = list( + method = "logistic" + ) + ) + ) + + parsnip::set_encoding( + model = "ordinal_reg", + eng = "polr", + mode = "classification", + options = list( + predictor_indicators = "traditional", + compute_intercept = TRUE, + remove_intercept = FALSE, + allow_sparse_x = FALSE + ) + ) + + parsnip::set_pred( + model = "ordinal_reg", + eng = "polr", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "class" + ) + ) + ) + + parsnip::set_pred( + model = "ordinal_reg", + eng = "polr", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + tibble::as_tibble(x) + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "probs" + ) + ) + ) + +} + +# nocov end diff --git a/R/zzz.R b/R/zzz.R index 07f1322..f175ccd 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -5,8 +5,8 @@ # been loaded. .onLoad <- function(libname, pkgname) { + make_ordinal_reg_polr() make_rand_forest_ordinalForest() } # nocov end - diff --git a/man/ordered-package.Rd b/man/ordered-package.Rd index dd39024..a44cf0f 100644 --- a/man/ordered-package.Rd +++ b/man/ordered-package.Rd @@ -4,13 +4,50 @@ \name{ordered-package} \alias{ordered} \alias{ordered-package} -\title{ordered: Wrappers for Ordinal Classification Models} +\title{{ordered}: parsnip Engines for Ordinal Regression Models} \description{ -Bindings for ordinal classification models for use with the 'parsnip' package, such as ordinal random forests by Hornung R. (2020) \doi{10.1007/s00357-018-9302-x} and others. +{ordered} provides engines for ordinal regression models for the {parsnip} +package. The models may have cumulative, sequential, or adjacent-category +structure, and in future these may be disaggregated into separate model +types. A vignette will provide thorough illustrations of {ordered} +functionality. See below for examples of fitting ordinal regression models +with {censored}. +} +\examples{ +# Weighted sample + +set.seed(561246) +house_sub <- MASS::housing |> + dplyr::sample_n(size = 120, replace = TRUE, weight = Freq) |> + subset(select = -Freq) +train_inds <- sample(120, 80) +house_train <- house_sub[train_inds, ] +house_test <- house_sub[-train_inds, ] + +# Cumulative-link proportional-odds probit regression model + +fit_cpop <- ordinal_reg() |> + set_engine("polr") |> + set_args(method = "probit") |> + fit(Sat ~ Infl + Type + Cont, data = house_train) +predict(fit_cpop, house_test, type = "prob") + +# Ordinal forest + +fit_orf <- rand_forest(mode = "classification") |> + set_engine("ordinalForest") |> + fit(Sat ~ Infl + Type + Cont, data = house_train) +predict(fit_orf, house_test, type = "prob") + } \author{ \strong{Maintainer}: Max Kuhn \email{max@posit.co} (\href{https://orcid.org/0000-0003-2402-136X}{ORCID}) +Authors: +\itemize{ + \item Jason Cory Brunson \email{cornelioid@gmail.com} (\href{https://orcid.org/0000-0003-3126-9494}{ORCID}) +} + Other contributors: \itemize{ \item Posit Software PBC [copyright holder] diff --git a/tests/testthat/helper-data.R b/tests/testthat/helper-data.R index d676f34..6252b66 100644 --- a/tests/testthat/helper-data.R +++ b/tests/testthat/helper-data.R @@ -1,3 +1,5 @@ +# https://testthat.r-lib.org/articles/skipping.html#helpers + if (rlang::is_installed("QSARdata")) { library(dplyr) data(caco, package = "QSARdata") @@ -12,3 +14,13 @@ if (rlang::is_installed("QSARdata")) { caco_train <- caco_dat[-c(1:2, 21:22, 41:42), ] caco_test <- caco_dat[ c(1:2, 21:22, 41:42), ] } + +get_house <- function() { + require(MASS) + set.seed(581837) + house_data <- MASS::housing + house_sub <- + house_data[sample(72, 120, replace = TRUE, prob = house_data$Freq), ] |> + subset(select = -Freq) + list(data = house_data, sub = house_sub) +} diff --git a/tests/testthat/test-ordinal_reg.R b/tests/testthat/test-ordinal_reg.R new file mode 100644 index 0000000..9fc2566 --- /dev/null +++ b/tests/testthat/test-ordinal_reg.R @@ -0,0 +1,9 @@ +# Test model type and engine arguments here rather than in {parsnip} if they +# require engines to be loaded. + +test_that("check_args() works", { + skip_if_not_installed("parsnip", "1.2.1.9003") + + # Here for completeness, no checking is done + expect_true(TRUE) +}) diff --git a/tests/testthat/test-ordinal_reg_polr.R b/tests/testthat/test-ordinal_reg_polr.R new file mode 100644 index 0000000..81fd380 --- /dev/null +++ b/tests/testthat/test-ordinal_reg_polr.R @@ -0,0 +1,97 @@ + +# model: basic ----------------------------------------------------------------- + +test_that("model object", { + skip_if_not_installed("MASS") + house_sub <- get_house()$sub + + orig_fit <- MASS::polr( + Sat ~ Type + Infl + Cont, + data = house_sub, + model = TRUE + ) + + tidy_spec <- ordinal_reg() |> + set_engine("polr") |> + set_mode("classification") + tidy_fit <- fit(tidy_spec, Sat ~ Type + Infl + Cont, data = house_sub) + + # remove `call` from comparison + orig_fit$call <- NULL + tidy_fit$fit$call <- NULL + + expect_equal( + orig_fit, + tidy_fit$fit, + ignore_formula_env = TRUE + ) +}) + +# model: case weights ---------------------------------------------------------- + +test_that("case weights", { + skip_if_not_installed("MASS") + house_data <- get_house()$data + + orig_fit <- MASS::polr( + Sat ~ Type + Infl + Cont, + data = house_data, + weights = Freq, + model = TRUE + ) + + tidy_spec <- ordinal_reg() |> + set_engine("polr") |> + set_mode("classification") + tidy_data <- transform(house_data, Freq = frequency_weights(Freq)) + tidy_fit <- fit( + tidy_spec, + Sat ~ Type + Infl + Cont, + data = tidy_data, + case_weights = tidy_data$Freq + ) + + orig_fit$call <- NULL + tidy_fit$fit$call <- NULL + + expect_equal( + orig_fit, + tidy_fit$fit, + ignore_formula_env = TRUE + ) +}) + +# prediction: probability ------------------------------------------------------ + +test_that("probability prediction", { + skip_if_not_installed("MASS") + house_sub <- get_house()$sub + + tidy_fit <- ordinal_reg() |> + set_engine("polr") |> + fit(Sat ~ Type + Cont, data = house_sub) + + orig_pred <- predict(tidy_fit$fit, newdata = house_sub, type = "probs") + orig_pred <- tibble::as_tibble(orig_pred) + orig_pred <- set_names(orig_pred, paste0(".pred_", names(orig_pred))) + tidy_pred <- predict(tidy_fit, house_sub, type = "prob") + expect_equal(orig_pred, tidy_pred) +}) + +# prediction: class ------------------------------------------------------------ + +test_that("class prediction", { + skip_if_not_installed("MASS") + house_sub <- get_house()$sub + + tidy_fit <- ordinal_reg() |> + set_engine("polr") |> + fit(Sat ~ Infl + Cont, data = house_sub) + + orig_pred <- predict(tidy_fit$fit, house_sub) + # NB: `MASS:::predict.polr()` strips order from `object$model$`. + orig_pred <- ordered(unname(orig_pred), levels(orig_pred)) + orig_pred <- tibble::tibble(.pred_class = orig_pred) + tidy_pred <- predict(tidy_fit, house_sub) + expect_equal(orig_pred, tidy_pred) +})