diff --git a/R/arguments.R b/R/arguments.R index 5b88b0c..66646f1 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -41,32 +41,6 @@ make_x_call <- function(object, target) { fit_call } -make_form_call <- function(object, env = NULL) { - fit_args <- object$method$fit$args - - # Get the arguments related to data: - if (is.null(object$method$fit$data)) { - data_args <- c(formula = "formula", data = "data") - } else { - data_args <- object$method$fit$data - } - - # add data arguments - for (i in seq_along(data_args)) { - fit_args[[unname(data_args[i])]] <- sym(names(data_args)[i]) - } - - # sub in actual formula - fit_args[[unname(data_args["formula"])]] <- env$formula - - fit_call <- make_call( - fun = object$method$fit$func["fun"], - ns = object$method$fit$func["pkg"], - fit_args - ) - fit_call -} - #' Change arguments of a cluster specification #' #' @inheritParams parsnip::set_args diff --git a/R/convert_data.R b/R/convert_data.R index b502b92..23b1f1d 100644 --- a/R/convert_data.R +++ b/R/convert_data.R @@ -8,14 +8,12 @@ #' internals of `lm()` (and also see the notes at #' https://developer.r-project.org/model-fitting-functions.html). #' -#' `.convert_form_to_x_fit()` and `.convert_x_to_form_fit()` are for when the -#' data are created for modeling. `.convert_form_to_x_fit()` saves both the -#' data objects as well as the objects needed when new data are predicted -#' (e.g. `terms`, etc.). +#' `.convert_form_to_x_fit()` is for when the data are created for modeling. +#' It saves both the data objects as well as the objects needed when new data +#' are predicted (e.g. `terms`, etc.). #' -#' `.convert_form_to_x_new()` and `.convert_x_to_form_new()` are used when new -#' samples are being predicted and only require the predictors to be -#' available. +#' `.convert_form_to_x_new()` is used when new samples are being predicted and +#' only requires the predictors to be available. #' #' @param data A data frame containing all relevant variables (e.g. predictors, #' case weights, etc). @@ -144,68 +142,6 @@ local_one_hot_contrasts <- function(frame = rlang::caller_env()) { rlang::local_options(contrasts = contrasts, .frame = frame) } -# ------------------------------------------------------------------------------ - -# The other direction where we make a formula from the data -# objects - -# TODO slots for other roles -#' @param weights A numeric vector containing the weights. -#' @inheritParams fit.cluster_spec -#' @inheritParams .convert_form_to_x_fit -#' @rdname convert_helpers -#' @keywords internal -.convert_x_to_form_fit <- function(x, weights = NULL, remove_intercept = TRUE) { - if (is.vector(x)) { - cli::cli_abort("{.arg x} cannot be a vector.") - } - - if (remove_intercept) { - x <- x[, colnames(x) != "(Intercept)", drop = FALSE] - } - - rn <- rownames(x) - - if (!is.data.frame(x)) { - x <- as.data.frame(x) - } - - x_var <- names(x) - form <- make_formula(names(x)) - - x <- bind_cols(x, y) - if (!is.null(rn) && !inherits(x, "tbl_df")) { - rownames(x) <- rn - } - - if (!is.null(weights)) { - if (!is.numeric(weights)) { - cli::cli_abort("The {.arg weights} must be a numeric vector.") - } - if (length(weights) != nrow(x)) { - cli::cli_abort("{.arg weights} should have {nrow(x)} elements.") - } - } - - res <- list( - formula = form, - data = x, - weights = weights, - x_var = x_var - ) - res -} - -make_formula <- function(x, short = TRUE) { - y_part <- "~" - if (short) { - form_text <- paste0(y_part, ".") - } else { - form_text <- paste0(y_part, paste0(x, collapse = "+")) - } - as.formula(form_text) -} - #' @param object An object of class [`cluster_fit`]. #' @inheritParams predict.cluster_fit #' @rdname convert_helpers @@ -262,13 +198,3 @@ make_formula <- function(x, short = TRUE) { } list(x = new_data, offset = offset) } - -#' @rdname convert_helpers -#' @keywords internal -.convert_x_to_form_new <- function(object, new_data) { - new_data <- new_data[, object$x_var, drop = FALSE] - if (!is.data.frame(new_data)) { - new_data <- as.data.frame(new_data) - } - new_data -} diff --git a/R/engine_docs.R b/R/engine_docs.R index dc276b5..d6f93c7 100644 --- a/R/engine_docs.R +++ b/R/engine_docs.R @@ -1,3 +1,4 @@ +# nocov start # https://github.com/tidymodels/parsnip/blob/main/R/engine_docs.R #' Knit engine-specific documentation @@ -64,3 +65,4 @@ list_md_problems <- function() { map(md_files, get_errors) |> vctrs::vec_rbind() } +# nocov end diff --git a/R/fit.R b/R/fit.R index 5cf4f8b..942c8ab 100644 --- a/R/fit.R +++ b/R/fit.R @@ -120,48 +120,29 @@ fit.cluster_spec <- function( eval_env$data <- data eval_env$formula <- formula - fit_interface <- - check_interface(eval_env$formula, eval_env$data, cl, object) + check_interface(eval_env$formula, eval_env$data, cl, object) # populate `method` with the details for this model type object <- add_methods(object, engine = object$engine) check_installs(object) - interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_") - - # Now call the wrappers that transition between the interface - # called here ("fit" interface) that will direct traffic to - # what the underlying model uses. For example, if a formula is - # used here, `fit_interface_formula` will determine if a - # translation has to be made if the model interface is x/y/ - res <- - switch( - interfaces, - # homogeneous combinations: - formula_formula = form_form( - object = object, - control = control, - env = eval_env - ), - - # heterogenous combinations - formula_matrix = form_x( - object = object, - control = control, - env = eval_env, - target = object$method$fit$interface, - ... - ), - formula_data.frame = form_x( - object = object, - control = control, - env = eval_env, - target = object$method$fit$interface, - ... - ), - cli::cli_abort("{interfaces} is unknown.") + interface <- object$method$fit$interface + if (!interface %in% c("matrix", "data.frame")) { + # nocov start + cli::cli_abort( + "Column interface {.val {interface}} is not supported.", + .internal = TRUE ) + } # nocov end + + res <- form_x( + object = object, + control = control, + env = eval_env, + target = interface, + ... + ) model_classes <- class(res$fit) class(res) <- c(paste0("_", model_classes[1]), "cluster_fit") res <- modelenv::new_unsupervised_fit(res) @@ -177,7 +158,7 @@ check_interface <- function(formula, data, cl, model) { if (form_interface) { return("formula") } - cli::cli_abort("Error when checking the interface.") + cli::cli_abort("Error when checking the interface.", .internal = TRUE) # nocov } inher <- function(x, cls, cl) { @@ -249,51 +230,29 @@ fit_xy.cluster_spec <- cl <- match.call(expand.dots = TRUE) eval_env <- rlang::env() eval_env$x <- x - fit_interface <- check_x_interface(eval_env$x, cl, object) + check_x_interface(eval_env$x, cl, object) # populate `method` with the details for this model type object <- add_methods(object, engine = object$engine) check_installs(object) - interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_") - - # Now call the wrappers that transition between the interface - # called here ("fit" interface) that will direct traffic to - # what the underlying model uses. For example, if a formula is - # used here, `fit_interface_formula` will determine if a - # translation has to be made if the model interface is x/y/ - res <- - switch( - interfaces, - # homogeneous combinations: - matrix_matrix = , - data.frame_matrix = x_x( - object = object, - env = eval_env, - control = control, - target = "matrix", - ... - ), - data.frame_data.frame = , - matrix_data.frame = x_x( - object = object, - env = eval_env, - control = control, - target = "data.frame", - ... - ), - - # heterogenous combinations - matrix_formula = , - data.frame_formula = x_form( - object = object, - env = eval_env, - control = control, - ... - ), - cli::cli_abort("{interfaces} is unknown.") + interface <- object$method$fit$interface + if (!interface %in% c("matrix", "data.frame")) { + # nocov start + cli::cli_abort( + "Column interface {.val {interface}} is not supported.", + .internal = TRUE ) + } # nocov end + + res <- x_x( + object = object, + env = eval_env, + control = control, + target = interface, + ... + ) model_classes <- class(res$fit) class(res) <- c(paste0("_", model_classes[1]), "cluster_fit") res @@ -328,7 +287,7 @@ check_x_interface <- function(x, cl, model) { if (df_interface) { return("data.frame") } - cli::cli_abort("Error when checking the interface") + cli::cli_abort("Error when checking the interface", .internal = TRUE) # nocov } allow_sparse <- function(x) { diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 49dd704..605fa0c 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -1,43 +1,5 @@ # https://github.com/tidymodels/parsnip/blob/main/R/fit_helpers.R -form_form <- function(object, control, env, ...) { - # evaluate quoted args once here to check them - object <- check_args(object) - - # sub in arguments to actual syntax for corresponding engine - object <- translate_tidyclust(object, engine = object$engine) - - fit_call <- make_form_call(object, env = env) - - res <- list( - spec = object - ) - - if (control$verbosity > 1L) { - elapsed <- system.time( - res$fit <- eval_mod( - fit_call, - capture = control$verbosity == 0, - catch = control$catch, - env = env, - ... - ), - gcFirst = FALSE - ) - } else { - res$fit <- eval_mod( - fit_call, - capture = control$verbosity == 0, - catch = control$catch, - env = env, - ... - ) - elapsed <- list(elapsed = NA_real_) - } - res$elapsed <- elapsed - res -} - form_x <- function(object, control, env, target = "none", ...) { encoding_info <- modelenv::get_encoding(class(object)[1]) |> @@ -116,30 +78,3 @@ x_x <- function(object, env, control, target = "none", y = NULL, ...) { res$elapsed <- elapsed res } - -x_form <- function(object, env, control, ...) { - encoding_info <- - modelenv::get_encoding(class(object)[1]) |> - dplyr::filter(mode == object$mode, engine == object$engine) - - remove_intercept <- encoding_info |> dplyr::pull(remove_intercept) - - data_obj <- - .convert_x_to_form_fit( - x = env$x, - weights = NULL, - remove_intercept = remove_intercept - ) - env$formula <- data_obj$formula - env$data <- data_obj$data - - # which terms etc goes in the preproc slot here? - res <- form_form( - object = object, - env = env, - control = control, - ... - ) - res$preproc <- data_obj[c("x_var")] - res -} diff --git a/R/metric-aaa.R b/R/metric-aaa.R index 0de27d4..c9a097d 100644 --- a/R/metric-aaa.R +++ b/R/metric-aaa.R @@ -61,10 +61,13 @@ cluster_metric_set <- function(...) { if (fn_cls == "cluster_metric") { make_cluster_metric_function(fns) } else { + # nocov start cli::cli_abort( - "Internal error: {.fn validate_function_class} should have errored on - unknown classes." + "Internal error: {.fn validate_function_class} should have errored on + unknown classes.", + .internal = TRUE ) + # nocov end } } @@ -94,10 +97,13 @@ validate_inputs_are_functions <- function(fns) { get_quo_label <- function(quo) { out <- rlang::as_label(quo) if (length(out) != 1L) { + # nocov start cli::cli_abort( "Internal error: {.code as_label(quo)} resulted in a character vector - of length > 1." + of length > 1.", + .internal = TRUE ) + # nocov end } is_namespaced <- grepl("::", out, fixed = TRUE) if (is_namespaced) { diff --git a/R/predict.R b/R/predict.R index a43c9b4..9c721c4 100644 --- a/R/predict.R +++ b/R/predict.R @@ -141,13 +141,8 @@ prepare_data <- function(object, new_data) { fit_interface <- object$spec$method$fit$interface pp_names <- names(object$preproc) - if (any(pp_names == "terms") || any(pp_names == "x_var")) { - # Translation code - if (fit_interface == "formula") { - new_data <- .convert_x_to_form_new(object$preproc, new_data) - } else { - new_data <- .convert_form_to_x_new(object$preproc, new_data)$x - } + if (any(pp_names == "terms")) { + new_data <- .convert_form_to_x_new(object$preproc, new_data)$x } remove_intercept <- diff --git a/man/convert_helpers.Rd b/man/convert_helpers.Rd index df11407..e631ca8 100644 --- a/man/convert_helpers.Rd +++ b/man/convert_helpers.Rd @@ -2,9 +2,7 @@ % Please edit documentation in R/convert_data.R \name{.convert_form_to_x_fit} \alias{.convert_form_to_x_fit} -\alias{.convert_x_to_form_fit} \alias{.convert_form_to_x_new} -\alias{.convert_x_to_form_new} \title{Helper functions to convert between formula and matrix interface} \usage{ .convert_form_to_x_fit( @@ -17,16 +15,12 @@ remove_intercept = TRUE ) -.convert_x_to_form_fit(x, weights = NULL, remove_intercept = TRUE) - .convert_form_to_x_new( object, new_data, na.action = stats::na.pass, composition = "data.frame" ) - -.convert_x_to_form_new(object, new_data) } \arguments{ \item{formula}{An object of class \code{formula} (or one that can be coerced to @@ -50,12 +44,6 @@ should be returned as a \code{"matrix"} or a \code{"data.frame"}.} \item{remove_intercept}{A logical indicating whether to remove the intercept column after \code{model.matrix()} is finished.} -\item{x}{A matrix, sparse matrix, or data frame of predictors. Only some -models have support for sparse matrix input. See \code{modelenv::get_encoding()} -for details. \code{x} should have column names.} - -\item{weights}{A numeric vector containing the weights.} - \item{object}{An object of class \code{\link{cluster_fit}}.} \item{new_data}{A rectangular data object, such as a data frame.} @@ -67,13 +55,11 @@ are intended for developer use. For the most part, this emulates the internals of \code{lm()} (and also see the notes at https://developer.r-project.org/model-fitting-functions.html). -\code{.convert_form_to_x_fit()} and \code{.convert_x_to_form_fit()} are for when the -data are created for modeling. \code{.convert_form_to_x_fit()} saves both the -data objects as well as the objects needed when new data are predicted -(e.g. \code{terms}, etc.). +\code{.convert_form_to_x_fit()} is for when the data are created for modeling. +It saves both the data objects as well as the objects needed when new data +are predicted (e.g. \code{terms}, etc.). -\code{.convert_form_to_x_new()} and \code{.convert_x_to_form_new()} are used when new -samples are being predicted and only require the predictors to be -available. +\code{.convert_form_to_x_new()} is used when new samples are being predicted and +only requires the predictors to be available. } \keyword{internal} diff --git a/tests/testthat/_snaps/cluster_metric_set.md b/tests/testthat/_snaps/cluster_metric_set.md index c73e991..f7dcda8 100644 --- a/tests/testthat/_snaps/cluster_metric_set.md +++ b/tests/testthat/_snaps/cluster_metric_set.md @@ -43,3 +43,47 @@ Error in `cluster_metric_set()`: ! `sse_within_total` is not a cluster metric. Did you mean `sse_within_total`? +# cluster_metric_set() errors when empty + + Code + cluster_metric_set() + Condition + Error in `validate_not_empty()`: + ! `cluster_metric_set()` requires at least 1 function supplied to `...`. + +# cluster_metric_set() errors with non-functions + + Code + cluster_metric_set("not_a_function") + Condition + Error in `validate_inputs_are_functions()`: + ! All inputs to `cluster_metric_set()` must be functions. + i These inputs are not: 1. + +# print.cluster_metric_set() works + + Code + print(metrics) + Output + # A tibble: 2 x 3 + metric class direction + + 1 sse_total cluster_metric zero + 2 sse_ratio cluster_metric zero + +# new_cluster_metric() errors with non-function + + Code + new_cluster_metric("not a function", direction = "maximize") + Condition + Error in `new_cluster_metric()`: + ! `fn` must be a function. + +# new_cluster_metric() errors with invalid direction + + Code + new_cluster_metric(fn, direction = "invalid") + Condition + Error in `new_cluster_metric()`: + ! `direction` must be one of "maximize", "minimize", or "zero", not "invalid". + diff --git a/tests/testthat/_snaps/control.md b/tests/testthat/_snaps/control.md index 7577956..f75b2c6 100644 --- a/tests/testthat/_snaps/control.md +++ b/tests/testthat/_snaps/control.md @@ -22,3 +22,35 @@ Error in `control_cluster()`: ! catch should be a logical. +# print.control_cluster() works + + Code + print(control_cluster()) + Output + tidyclust control object + +--- + + Code + print(control_cluster(verbosity = 2)) + Output + tidyclust control object + - verbose level 2 + +--- + + Code + print(control_cluster(catch = TRUE)) + Output + tidyclust control object + - fit errors will be caught + +--- + + Code + print(control_cluster(verbosity = 2, catch = TRUE)) + Output + tidyclust control object + - verbose level 2 + - fit errors will be caught + diff --git a/tests/testthat/_snaps/convert_data.md b/tests/testthat/_snaps/convert_data.md new file mode 100644 index 0000000..7bb0d82 --- /dev/null +++ b/tests/testthat/_snaps/convert_data.md @@ -0,0 +1,33 @@ +# .convert_form_to_x_fit() errors on invalid composition + + Code + .convert_form_to_x_fit(~., mtcars, composition = "invalid") + Condition + Error in `.convert_form_to_x_fit()`: + ! `composition` should be or . + +# .convert_form_to_x_fit() errors on non-numeric weights + + Code + .convert_form_to_x_fit(~., mtcars, weights = letters[1:32]) + Condition + Error in `.convert_form_to_x_fit()`: + ! The `weights` must be a numeric vector. + +# .convert_form_to_x_fit() errors on invalid dots arguments + + Code + .convert_form_to_x_fit(~., mtcars, bad_arg = 1) + Condition + Error in `check_form_dots()`: + ! The argument `bad_arg` cannot be used to create the data. + i Possible arguments are: `subset` and `weights`. + +# .convert_form_to_x_new() errors on invalid composition + + Code + .convert_form_to_x_new(fit$preproc, mtcars, composition = "invalid") + Condition + Error in `.convert_form_to_x_new()`: + ! `composition` should be either `data.frame` or `matrix`. + diff --git a/tests/testthat/_snaps/fitting.md b/tests/testthat/_snaps/fitting.md index 37426af..e80aaed 100644 --- a/tests/testthat/_snaps/fitting.md +++ b/tests/testthat/_snaps/fitting.md @@ -23,3 +23,54 @@ ! `num_clusters` must be at most the number of distinct data points (4). i `num_clusters` was set to 10. +# fit() errors when mode is unknown + + Code + fit(spec, ~., data = mtcars) + Condition + Error in `fit()`: + ! Please set the mode in the model specification. + +# fit() uses default engine when not set and verbosity > 0 + + Code + fit <- fit(k_means(num_clusters = 3, engine = NULL), ~., data = mtcars[1:10, ], + control = control_cluster(verbosity = 1)) + Condition + Warning: + Engine set to `stats`. + +# fit_xy() uses default engine when not set and verbosity > 0 + + Code + fit <- fit_xy(k_means(num_clusters = 3, engine = NULL), mtcars[1:10, ], + control = control_cluster(verbosity = 1)) + Condition + Warning: + Engine set to `stats`. + +# fit() errors when called with x and y arguments + + Code + fit(set_engine(k_means(num_clusters = 3), "stats"), ~., data = mtcars, x = mtcars, + y = mtcars$mpg) + Condition + Error in `fit()`: + ! The `fit.cluster_spec()` function is for the formula methods. Use `fit_xy()` instead. + +# fit() errors when formula is not a formula + + Code + fit(set_engine(k_means(num_clusters = 3), "stats"), "not a formula", data = mtcars) + Condition + Error in `inher()`: + ! `formula` should be a string. + +# fit_xy() errors when x has no column names + + Code + fit_xy(set_engine(k_means(num_clusters = 3), "stats"), mat) + Condition + Error in `fit_xy()`: + ! `x` should have column names. + diff --git a/tests/testthat/_snaps/hier_clust.md b/tests/testthat/_snaps/hier_clust.md index 1ad05e8..e77d68b 100644 --- a/tests/testthat/_snaps/hier_clust.md +++ b/tests/testthat/_snaps/hier_clust.md @@ -72,3 +72,12 @@ Computational engine: stats +# check_args.hier_clust() errors on negative num_clusters + + Code + spec <- set_engine(hier_clust(num_clusters = -1), "stats") + fit(spec, ~., data = mtcars) + Condition + Error in `check_args()`: + ! The number of centers should be >= 0. + diff --git a/tests/testthat/_snaps/predict_raw.md b/tests/testthat/_snaps/predict_raw.md new file mode 100644 index 0000000..18b0823 --- /dev/null +++ b/tests/testthat/_snaps/predict_raw.md @@ -0,0 +1,8 @@ +# predict_raw() warns for try-error fit + + Code + res <- predict_raw(fit, mtcars[1:5, ]) + Condition + Warning: + Cluster fit failed; cannot make predictions. + diff --git a/tests/testthat/_snaps/print.md b/tests/testthat/_snaps/print.md index 921d4ac..432808f 100644 --- a/tests/testthat/_snaps/print.md +++ b/tests/testthat/_snaps/print.md @@ -85,3 +85,78 @@ Number of objects: 32 +# print.cluster_fit() shows elapsed time when verbosity > 1 + + Code + print(fit) + Output + tidyclust cluster object + + Fit time: + K-means clustering with 3 clusters of sizes 7, 11, 14 + + Cluster means: + mpg cyl disp hp drat wt qsec vs + 1 19.74286 6 183.3143 122.28571 3.585714 3.117143 17.97714 0.5714286 + 3 26.66364 4 105.1364 82.63636 4.070909 2.285727 19.13727 0.9090909 + 2 15.10000 8 353.1000 209.21429 3.229286 3.999214 16.77214 0.0000000 + am gear carb + 1 0.4285714 3.857143 3.428571 + 3 0.7272727 4.090909 1.545455 + 2 0.1428571 3.285714 3.500000 + + Clustering vector: + Mazda RX4 Mazda RX4 Wag Datsun 710 Hornet 4 Drive + 1 1 2 1 + Hornet Sportabout Valiant Duster 360 Merc 240D + 3 1 3 2 + Merc 230 Merc 280 Merc 280C Merc 450SE + 2 1 1 3 + Merc 450SL Merc 450SLC Cadillac Fleetwood Lincoln Continental + 3 3 3 3 + Chrysler Imperial Fiat 128 Honda Civic Toyota Corolla + 3 2 2 2 + Toyota Corona Dodge Challenger AMC Javelin Camaro Z28 + 2 3 3 3 + Pontiac Firebird Fiat X1-9 Porsche 914-2 Lotus Europa + 3 2 2 2 + Ford Pantera L Ferrari Dino Maserati Bora Volvo 142E + 3 1 3 2 + + Within cluster sum of squares by cluster: + [1] 13954.34 11848.37 93643.90 + (between_SS / total_SS = 80.8 %) + + Available components: + + [1] "cluster" "centers" "totss" "withinss" "tot.withinss" + [6] "betweenss" "size" "iter" "ifault" + +# print.cluster_fit() handles try-error + + Code + print(fit) + Output + tidyclust cluster object + + Cluster fit failed with error: + Error in try(stop("intentional error for testing"), silent = TRUE) : + intentional error for testing + + +# print.cluster_spec() works with translated spec + + Code + print(spec) + Output + K Means Cluster Specification (partition) + + Main Arguments: + num_clusters = 3 + + Computational engine: stats + + Model fit template: + tidyclust::.k_means_fit_stats(x = missing_arg(), centers = missing_arg(), + centers = 3) + diff --git a/tests/testthat/test-cluster_metric_set.R b/tests/testthat/test-cluster_metric_set.R index 0e75863..c22fce9 100644 --- a/tests/testthat/test-cluster_metric_set.R +++ b/tests/testthat/test-cluster_metric_set.R @@ -64,3 +64,53 @@ test_that("cluster_metric_set() errors with advice for some functions", { cluster_metric_set(sse_within) ) }) + +test_that("cluster_metric_set() errors when empty", { + expect_snapshot(error = TRUE, cluster_metric_set()) +}) + +test_that("cluster_metric_set() errors with non-functions", { + expect_snapshot(error = TRUE, cluster_metric_set("not_a_function")) +}) + +test_that("print.cluster_metric_set() works", { + metrics <- cluster_metric_set(sse_total, sse_ratio) + expect_snapshot(print(metrics)) +}) + +test_that("cluster_metric_set() works with namespaced functions", { + metrics <- cluster_metric_set(tidyclust::sse_total, tidyclust::sse_ratio) + + kmeans_fit <- k_means(num_clusters = 3) |> + set_engine("stats") |> + fit(~., mtcars) + + res <- metrics(kmeans_fit, new_data = mtcars) + expect_equal(res$.metric, c("sse_total", "sse_ratio")) +}) + +test_that("new_cluster_metric() works", { + fn <- function(object, new_data = NULL) { + tibble::tibble(.metric = "test", .estimator = "standard", .estimate = 1) + } + + metric <- new_cluster_metric(fn, direction = "maximize") + + expect_s3_class(metric, "cluster_metric") + expect_equal(attr(metric, "direction"), "maximize") +}) + +test_that("new_cluster_metric() errors with non-function", { + expect_snapshot( + error = TRUE, + new_cluster_metric("not a function", direction = "maximize") + ) +}) + +test_that("new_cluster_metric() errors with invalid direction", { + fn <- function(object, new_data = NULL) 1 + expect_snapshot( + error = TRUE, + new_cluster_metric(fn, direction = "invalid") + ) +}) diff --git a/tests/testthat/test-control.R b/tests/testthat/test-control.R index 43e8a58..ba2529c 100644 --- a/tests/testthat/test-control.R +++ b/tests/testthat/test-control.R @@ -31,3 +31,13 @@ test_that("control_cluster() error with wrong input", { control_cluster(catch = "yes") ) }) + +test_that("print.control_cluster() works", { + expect_snapshot(print(control_cluster())) + + expect_snapshot(print(control_cluster(verbosity = 2))) + + expect_snapshot(print(control_cluster(catch = TRUE))) + + expect_snapshot(print(control_cluster(verbosity = 2, catch = TRUE))) +}) diff --git a/tests/testthat/test-convert_data.R b/tests/testthat/test-convert_data.R new file mode 100644 index 0000000..f4185fd --- /dev/null +++ b/tests/testthat/test-convert_data.R @@ -0,0 +1,138 @@ +# .convert_form_to_x_fit() tests ----------------------------------------------- + +test_that(".convert_form_to_x_fit() errors on invalid composition", { + expect_snapshot( + error = TRUE, + .convert_form_to_x_fit(~., mtcars, composition = "invalid") + ) +}) + +test_that(".convert_form_to_x_fit() errors on non-numeric weights", { + expect_snapshot( + error = TRUE, + .convert_form_to_x_fit(~., mtcars, weights = letters[1:32]) + ) +}) + +test_that(".convert_form_to_x_fit() errors on invalid dots arguments", { + expect_snapshot( + error = TRUE, + .convert_form_to_x_fit(~., mtcars, bad_arg = 1) + ) +}) + +test_that(".convert_form_to_x_fit() processes subset argument", { + res <- .convert_form_to_x_fit(~., mtcars, subset = 1:10) + expect_equal(nrow(res$x), 10) +}) + +test_that(".convert_form_to_x_fit() works with one_hot indicators", { + skip("contr_one_hot not available in tidyclust") + + data <- data.frame(x = 1:10, cat = factor(rep(c("a", "b"), 5))) + res <- .convert_form_to_x_fit(~., data, indicators = "one_hot") + + expect_named(res$x, c("x", "cata", "catb")) +}) + +test_that(".convert_form_to_x_fit() works with indicators = none", { + data <- data.frame(x = 1:10, cat = factor(rep(c("a", "b"), 5))) + res <- .convert_form_to_x_fit(~., data, indicators = "none") + + expect_s3_class(res$x, "data.frame") + expect_named(res$x, c("x", "cat")) +}) + +test_that(".convert_form_to_x_fit() returns matrix when requested", { + res <- .convert_form_to_x_fit(~., mtcars, composition = "matrix") + expect_true(is.matrix(res$x)) +}) + +test_that(".convert_form_to_x_fit() returns data.frame by default", { + res <- .convert_form_to_x_fit(~., mtcars, composition = "data.frame") + expect_s3_class(res$x, "data.frame") +}) + +test_that(".convert_form_to_x_fit() removes intercept column by default", { + res <- .convert_form_to_x_fit(~., mtcars, remove_intercept = TRUE) + expect_false("(Intercept)" %in% colnames(res$x)) +}) + +test_that(".convert_form_to_x_fit() keeps intercept when requested", { + res <- .convert_form_to_x_fit(~., mtcars, remove_intercept = FALSE) + expect_named(res$x, c("(Intercept)", names(mtcars))) +}) + +test_that(".convert_form_to_x_fit() returns terms object", { + res <- .convert_form_to_x_fit(~., mtcars) + expect_s3_class(res$terms, "terms") +}) + +test_that(".convert_form_to_x_fit() returns xlevels for factors", { + data <- data.frame(x = 1:10, cat = factor(rep(c("a", "b"), 5))) + res <- .convert_form_to_x_fit(~., data) + expect_named(res$xlevels, "cat") +}) + +test_that(".convert_form_to_x_fit() accepts valid weights", { + res <- .convert_form_to_x_fit(~., mtcars, weights = rep(1, 32)) + expect_equal(res$weights, rep(1, 32)) +}) + +test_that(".convert_form_to_x_fit() stores options", { + res <- .convert_form_to_x_fit( + ~., + mtcars, + indicators = "traditional", + composition = "matrix", + remove_intercept = FALSE + ) + expect_equal(res$options$indicators, "traditional") + expect_equal(res$options$composition, "matrix") + expect_false(res$options$remove_intercept) +}) + +# .convert_form_to_x_new() tests ----------------------------------------------- + +test_that(".convert_form_to_x_new() errors on invalid composition", { + fit <- k_means(num_clusters = 3) |> set_engine("stats") |> fit(~., mtcars) + + expect_snapshot( + error = TRUE, + .convert_form_to_x_new(fit$preproc, mtcars, composition = "invalid") + ) +}) + +test_that(".convert_form_to_x_new() works with matrix composition", { + fit <- k_means(num_clusters = 3) |> set_engine("stats") |> fit(~., mtcars) + + res <- .convert_form_to_x_new( + fit$preproc, + mtcars[1:5, ], + composition = "matrix" + ) + expect_true(is.matrix(res$x)) + expect_equal(nrow(res$x), 5) +}) + +test_that(".convert_form_to_x_new() works with data.frame composition", { + fit <- k_means(num_clusters = 3) |> set_engine("stats") |> fit(~., mtcars) + + res <- .convert_form_to_x_new( + fit$preproc, + mtcars[1:5, ], + composition = "data.frame" + ) + expect_s3_class(res$x, "data.frame") +}) + +test_that(".convert_form_to_x_new() works with one_hot indicators", { + skip("contr_one_hot not available in tidyclust") + + data <- data.frame(x = 1:10, cat = factor(rep(c("a", "b"), 5))) + + preproc <- .convert_form_to_x_fit(~., data, indicators = "one_hot") + res <- .convert_form_to_x_new(preproc, data[1:3, ]) + + expect_named(res$x, c("x", "cata", "catb")) +}) diff --git a/tests/testthat/test-fitting.R b/tests/testthat/test-fitting.R index 4aa01b9..d3d62b5 100644 --- a/tests/testthat/test-fitting.R +++ b/tests/testthat/test-fitting.R @@ -51,3 +51,72 @@ test_that("hier_clust() with cut_height = 0 produces n clusters", { expect_identical(length(unique(res$.pred_cluster)), nrow(mtcars)) }) + +test_that("fit() errors when mode is unknown", { + spec <- k_means(num_clusters = 3) + spec$mode <- "unknown" + + expect_snapshot( + error = TRUE, + fit(spec, ~., data = mtcars) + ) +}) + +test_that("fit() uses default engine when not set and verbosity > 0", { + expect_snapshot( + fit <- k_means(num_clusters = 3, engine = NULL) |> + fit(~., data = mtcars[1:10, ], control = control_cluster(verbosity = 1)) + ) + expect_s3_class(fit, "cluster_fit") + expect_equal(fit$spec$engine, "stats") +}) + +test_that("fit_xy() uses default engine when not set and verbosity > 0", { + expect_snapshot( + fit <- k_means(num_clusters = 3, engine = NULL) |> + fit_xy(mtcars[1:10, ], control = control_cluster(verbosity = 1)) + ) + expect_s3_class(fit, "cluster_fit") + expect_equal(fit$spec$engine, "stats") +}) + +test_that("fit() errors when called with x and y arguments", { + expect_snapshot( + error = TRUE, + k_means(num_clusters = 3) |> + set_engine("stats") |> + fit(~., data = mtcars, x = mtcars, y = mtcars$mpg) + ) +}) + +test_that("fit() errors when formula is not a formula", { + expect_snapshot( + error = TRUE, + k_means(num_clusters = 3) |> + set_engine("stats") |> + fit("not a formula", data = mtcars) + ) +}) + +test_that("fit_xy() errors when x has no column names", { + mat <- as.matrix(mtcars) + colnames(mat) <- NULL + + expect_snapshot( + error = TRUE, + k_means(num_clusters = 3) |> + set_engine("stats") |> + fit_xy(mat) + ) +}) + +test_that("fit_xy() works with matrix input", { + mat <- as.matrix(mtcars) + + fit <- k_means(num_clusters = 3) |> + set_engine("stats") |> + fit_xy(mat) + + expect_s3_class(fit, "cluster_fit") +}) + diff --git a/tests/testthat/test-hier_clust.R b/tests/testthat/test-hier_clust.R index edef78e..370204e 100644 --- a/tests/testthat/test-hier_clust.R +++ b/tests/testthat/test-hier_clust.R @@ -191,6 +191,32 @@ test_that("updating", { ) }) +test_that("update.hier_clust() works with parameters tibble", { + spec <- hier_clust(num_clusters = 3) + params <- tibble::tibble(num_clusters = 5) + + updated <- update(spec, parameters = params) + expect_equal(rlang::eval_tidy(updated$args$num_clusters), 5) +}) + +test_that("update.hier_clust() works with fresh = TRUE", { + spec <- hier_clust(num_clusters = 3, linkage_method = "single") + + updated <- update(spec, num_clusters = 7, fresh = TRUE) + expect_equal(rlang::eval_tidy(updated$args$num_clusters), 7) + expect_null(rlang::eval_tidy(updated$args$linkage_method)) +}) + +test_that("check_args.hier_clust() errors on negative num_clusters", { + expect_snapshot( + error = TRUE, + { + spec <- hier_clust(num_clusters = -1) |> set_engine("stats") + fit(spec, ~., data = mtcars) + } + ) +}) + test_that("reordering is done correctly for stats hier_clust", { set.seed(42) diff --git a/tests/testthat/test-k_means.R b/tests/testthat/test-k_means.R index cb8d153..a17d172 100644 --- a/tests/testthat/test-k_means.R +++ b/tests/testthat/test-k_means.R @@ -120,6 +120,29 @@ test_that("updating", { ) }) +test_that("update.k_means() works with parameters tibble", { + spec <- k_means(num_clusters = 3) + params <- tibble::tibble(num_clusters = 5) + + updated <- update(spec, parameters = params) + expect_equal(rlang::eval_tidy(updated$args$num_clusters), 5) +}) + +test_that("update.k_means() works with fresh = TRUE", { + spec <- k_means(num_clusters = 3) |> set_engine("stats", nstart = 5) + + updated <- update(spec, num_clusters = 7, fresh = TRUE) + expect_equal(rlang::eval_tidy(updated$args$num_clusters), 7) + expect_null(rlang::eval_tidy(updated$eng_args$nstart)) +}) + +test_that("update.k_means() works with engine args", { + spec <- k_means(num_clusters = 3) |> set_engine("stats", nstart = 5) + + updated <- update(spec, nstart = 10) + expect_equal(rlang::eval_tidy(updated$eng_args$nstart), 10) +}) + test_that("Engine-specific arguments are passed to ClusterR models", { spec <- k_means(num_clusters = 2) |> set_engine("ClusterR", fuzzy = FALSE) diff --git a/tests/testthat/test-predict_raw.R b/tests/testthat/test-predict_raw.R new file mode 100644 index 0000000..e23a574 --- /dev/null +++ b/tests/testthat/test-predict_raw.R @@ -0,0 +1,15 @@ +test_that("predict_raw() warns for try-error fit", { + spec <- k_means(num_clusters = 3) |> set_engine("stats") + fit <- fit(spec, ~., data = mtcars) + fit$fit <- try(stop("intentional error"), silent = TRUE) + + # Manually set up the method to allow raw predictions for testing + + fit$spec$method$pred$raw <- list( + args = list(object = rlang::expr(object$fit)), + func = c(fun = "predict") + ) + + expect_snapshot(res <- predict_raw(fit, mtcars[1:5, ])) + expect_null(res) +}) diff --git a/tests/testthat/test-print.R b/tests/testthat/test-print.R index 331c85c..6a7d0a0 100644 --- a/tests/testthat/test-print.R +++ b/tests/testthat/test-print.R @@ -30,3 +30,29 @@ test_that("print.cluster_fit() works for hier_clust", { print(fit) ) }) + +test_that("print.cluster_fit() shows elapsed time when verbosity > 1", { + set.seed(1234) + spec <- k_means(num_clusters = 3) |> set_engine("stats") + ctrl <- control_cluster(verbosity = 2) + fit <- fit(spec, ~., data = mtcars, control = ctrl) + + scrub_time <- function(x) gsub("Fit time: .*", "Fit time: ", x) + expect_snapshot(print(fit), transform = scrub_time) +}) + +test_that("print.cluster_fit() handles try-error", { + set.seed(1234) + spec <- k_means(num_clusters = 3) |> set_engine("stats") + fit <- fit(spec, ~., data = mtcars) + fit$fit <- try(stop("intentional error for testing"), silent = TRUE) + + expect_snapshot(print(fit)) +}) + +test_that("print.cluster_spec() works with translated spec", { + spec <- k_means(num_clusters = 3) |> set_engine("stats") + spec <- translate_tidyclust(spec) + + expect_snapshot(print(spec)) +})