From ef32853ee157f2e173f1e2718a05bdedd0d64110 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 5 Mar 2026 15:16:52 -0800 Subject: [PATCH 1/4] switch to using loop_call() from tune itself --- R/tune_cluster.R | 705 ++++++----------------------------------------- 1 file changed, 83 insertions(+), 622 deletions(-) diff --git a/R/tune_cluster.R b/R/tune_cluster.R index a522fca..eab0bfe 100644 --- a/R/tune_cluster.R +++ b/R/tune_cluster.R @@ -168,17 +168,16 @@ tune_cluster_workflow <- function( pset = pset ) - # Save rset attributes, then fall back to a bare tibble + # Save rset attributes rset_info <- tune::pull_rset_attributes(resamples) - resamples <- new_bare_tibble(resamples) resamples <- tune_cluster_loop( resamples = resamples, grid = grid, workflow = workflow, + param_info = pset, metrics = metrics, - control = control, - rng = rng + control = control ) if (is_cataclysmic(resamples)) { @@ -201,657 +200,119 @@ tune_cluster_workflow <- function( ) } +# ------------------------------------------------------------------------------ + tune_cluster_loop <- function( resamples, grid, workflow, + param_info, metrics, - control, - rng -) { - `%op%` <- get_operator(control$allow_par, workflow) - `%:%` <- foreach::`%:%` - - packages <- c(control$pkgs, required_pkgs(workflow)) - - grid_info <- compute_grid_info(workflow, grid) - - n_resamples <- nrow(resamples) - iterations <- seq_len(n_resamples) - - n_grid_info <- nrow(grid_info) - rows <- seq_len(n_grid_info) - - splits <- resamples$splits - - parallel_over <- control$parallel_over - parallel_over <- parallel_over_finalize(parallel_over, n_resamples) - - rlang::local_options(doFuture.rng.onMisuse = "ignore") - - if (identical(parallel_over, "resamples")) { - seeds <- generate_seeds(rng, n_resamples) - - # We are wrapping in `local()` here because `fn_tune_grid_loop_iter_safely()` adds - # on.exit/deferred handlers to our execution frame by passing `tune_env$progress_env` - # to cli's progress bar constructor, which then adds an exit handler on that - # environment. Because `%op%` evaluates its expression in `eval()` in the calling - # environment (i.e. here), the handlers are added in the special frame environment - # created by `eval()`. This causes the handler to run much too early. By evaluating in - # a local environment, we prevent `defer()`/`on.exit()` from finding the short-lived - # context of `%op%`. Instead it looks all the way up here to register the handler. - - results <- local({ - suppressPackageStartupMessages( - foreach::foreach( - split = splits, - seed = seeds, - .packages = packages, - .errorhandling = "pass" - ) %op% - { - # Extract internal function from tune namespace - tune_cluster_loop_iter_safely <- utils::getFromNamespace( - x = "tune_cluster_loop_iter_safely", - ns = "tidyclust" - ) - - tune_cluster_loop_iter_safely( - split = split, - grid_info = grid_info, - workflow = workflow, - metrics = metrics, - control = control, - seed = seed - ) - } - ) - }) - } else if (identical(parallel_over, "everything")) { - seeds <- generate_seeds(rng, n_resamples * n_grid_info) - - results <- local( - suppressPackageStartupMessages( - foreach::foreach( - iteration = iterations, - split = splits, - .packages = packages, - .errorhandling = "pass" - ) %:% - foreach::foreach( - row = rows, - seed = slice_seeds(seeds, iteration, n_grid_info), - .packages = packages, - .errorhandling = "pass", - .combine = iter_combine - ) %op% - { - # Extract internal function from tidyclust namespace - tune_grid_loop_iter_safely <- utils::getFromNamespace( - x = "tune_cluster_loop_iter_safely", - ns = "tidyclust" - ) - - grid_info_row <- vctrs::vec_slice(grid_info, row) - - tune_grid_loop_iter_safely( - split = split, - grid_info = grid_info_row, - workflow = workflow, - metrics = metrics, - control = control, - seed = seed - ) - } - ) - ) - } else { - cli::cli_abort("Internal error: Invalid {.arg parallel_over}.") - } - - resamples <- pull_metrics(resamples, results, control) - resamples <- pull_notes(resamples, results, control) - resamples <- pull_extracts(resamples, results, control) - resamples <- pull_predictions(resamples, results, control) - resamples -} - -# https://github.com/tidymodels/tune/blob/main/R/grid_helpers.R#L299 -compute_grid_info <- function(workflow, grid) { + control +) { if (is.null(grid)) { - out <- new_grid_info_resamples() - return(out) - } - grid <- tibble::as_tibble(grid) - parameters <- hardhat::extract_parameter_set_dials(workflow) - parameters_model <- dplyr::filter(parameters, source == "cluster_spec") - parameters_preprocessor <- dplyr::filter(parameters, source == "recipe") - any_parameters_model <- nrow(parameters_model) > 0 - any_parameters_preprocessor <- nrow(parameters_preprocessor) > 0 - if (any_parameters_model) { - if (any_parameters_preprocessor) { - compute_grid_info_model_and_preprocessor( - workflow, - grid, - parameters_model - ) - } else { - compute_grid_info_model(workflow, grid, parameters_model) - } - } else { - if (any_parameters_preprocessor) { - compute_grid_info_preprocessor(workflow, grid, parameters_model) - } else { - cli::cli_abort( - c( - "Internal error: {.code workflow} should have some tunable parameters - if {.code grid} is not {.code NULL}." - ) - ) - } + grid <- tibble::tibble() } -} -tune_cluster_loop_iter <- function( - split, - grid_info, - workflow, - metrics, - control, - seed -) { - load_pkgs(workflow) - load_namespace(control$pkgs) - - # After package loading to avoid potential package RNG manipulation - if (!is.null(seed)) { - # `assign()`-ing the random seed alters the `kind` type to L'Ecuyer-CMRG, - # so we have to ensure it is restored on exit - old_kind <- RNGkind()[[1]] - assign(".Random.seed", seed, envir = globalenv()) - on.exit(RNGkind(kind = old_kind), add = TRUE) - } + control <- tune::.update_parallel_over(control, resamples, grid) - control_parsnip <- parsnip::control_parsnip(verbosity = 0, catch = TRUE) - control_workflow <- workflows::control_workflow(control_parsnip) + # Determine how to process the tasks + strategy <- tune::choose_framework(workflow, control) - event_level <- control$event_level + # Generate parallel seeds + resamples$.seeds <- tune::get_parallel_seeds(nrow(resamples)) - out_metrics <- NULL - out_extracts <- NULL - out_predictions <- NULL - out_notes <- tibble::tibble( - location = character(0), - type = character(0), - note = character(0) - ) + # Save rset attributes + rset_info <- tune::pull_rset_attributes(resamples) + split_args <- rsample::.get_split_args(resamples) - params <- hardhat::extract_parameter_set_dials(workflow) - model_params <- dplyr::filter(params, source == "cluster_spec") - preprocessor_params <- dplyr::filter(params, source == "recipe") - - param_names <- dplyr::pull(params, "id") - model_param_names <- dplyr::pull(model_params, "id") - preprocessor_param_names <- dplyr::pull(preprocessor_params, "id") - - # Model related grid-info columns - cols <- rlang::expr( - c( - .iter_model, - .iter_config, - .msg_model, - dplyr::all_of(model_param_names), - .submodels - ) + resamples <- new_bare_tibble(resamples) + resamples <- vec_list_rowwise(resamples) + + # Package loading + tm_pkgs <- c( + "rsample", + "workflows", + "hardhat", + "tidyclust", + "parsnip", + "tune" ) + load_pkgs <- c(required_pkgs(workflow), control$pkgs, tm_pkgs) + load_pkgs <- unique(load_pkgs) - # Nest grid_info: - # - Preprocessor info in the outer level - # - Model info in the inner level - grid_info <- tidyr::nest(grid_info, data = !!cols) - - training <- rsample::analysis(split) - - # ---------------------------------------------------------------------------- - # Preprocessor loop - - iter_preprocessors <- grid_info[[".iter_preprocessor"]] - - workflow_original <- workflow - - for (iter_preprocessor in iter_preprocessors) { - workflow <- workflow_original - - iter_grid_info <- dplyr::filter( - .data = grid_info, - .iter_preprocessor == iter_preprocessor - ) - - iter_grid_preprocessor <- dplyr::select( - .data = iter_grid_info, - dplyr::all_of(preprocessor_param_names) - ) - - iter_msg_preprocessor <- iter_grid_info[[".msg_preprocessor"]] - - workflow <- tune::finalize_workflow_preprocessor( - workflow = workflow, - grid_preprocessor = iter_grid_preprocessor - ) - - workflow <- catch_and_log( - .expr = workflows::.fit_pre(workflow, training), - control, - split, - iter_msg_preprocessor, - notes = out_notes + is_inst <- purrr::map_lgl(load_pkgs, rlang::is_installed) + if (any(!is_inst)) { + nms <- load_pkgs[!is_inst] + cli::cli_abort( + "Some package installs are needed: {.pkg {nms}}", + call = NULL ) + } - if (is_failure(workflow)) { - next - } - - # -------------------------------------------------------------------------- - # Model loop - - iter_grid_info_models <- iter_grid_info[["data"]][[1L]] - iter_models <- iter_grid_info_models[[".iter_model"]] - - workflow_preprocessed <- workflow - - for (iter_model in iter_models) { - workflow <- workflow_preprocessed - - iter_grid_info_model <- dplyr::filter( - .data = iter_grid_info_models, - .iter_model == iter_model - ) - - iter_grid_model <- dplyr::select( - .data = iter_grid_info_model, - dplyr::all_of(model_param_names) - ) - - iter_submodels <- iter_grid_info_model[[".submodels"]][[1L]] - iter_msg_model <- iter_grid_info_model[[".msg_model"]] - iter_config <- iter_grid_info_model[[".iter_config"]][[1L]] - - workflow <- finalize_workflow_spec(workflow, iter_grid_model) - - workflow <- catch_and_log_fit( - expr = workflows::.fit_model(workflow, control_workflow), - control, - split, - iter_msg_model, - notes = out_notes - ) - - # Check for parsnip level and model level failure - if (is_failure(workflow) || is_failure(workflow$fit$fit$fit)) { - next - } - - workflow <- workflows::.fit_finalize(workflow) - - # FIXME: I think this might be wrong? Doesn't use submodel parameters, - # so `extracts` column doesn't list the correct parameters. - iter_grid <- dplyr::bind_cols( - iter_grid_preprocessor, - iter_grid_model - ) - - # FIXME: bind_cols() drops number of rows with zero col data frames - # because of a bug with vec_cbind() - # https://github.com/r-lib/vctrs/issues/1281 - if (ncol(iter_grid_preprocessor) == 0L && ncol(iter_grid_model) == 0L) { - nrow <- nrow(iter_grid_model) - iter_grid <- tibble::new_tibble(x = list(), nrow = nrow) - } - - out_extracts <- append_extracts( - collection = out_extracts, - workflow = workflow, - grid = iter_grid, - split = split, - ctrl = control, - .config = iter_config - ) - - iter_msg_predictions <- paste(iter_msg_model, "(predictions)") - - iter_predictions <- catch_and_log( - predict_model(split, workflow, iter_grid, metrics, iter_submodels), - control, - split, - iter_msg_predictions, - bad_only = TRUE, - notes = out_notes - ) - - # Check for prediction level failure - if (is_failure(iter_predictions)) { - next - } - - out_metrics <- append_metrics( - workflow = workflow, - collection = out_metrics, - predictions = iter_predictions, - metrics = metrics, - param_names = param_names, - event_level = event_level, - split = split, - .config = iter_config - ) - - iter_config_metrics <- extract_metrics_config(param_names, out_metrics) - - out_predictions <- append_predictions( - collection = out_predictions, - predictions = iter_predictions, - split = split, - control = control, - .config = iter_config_metrics - ) - } # model loop - } # preprocessor loop - - list( - .metrics = out_metrics, - .extracts = out_extracts, - .predictions = out_predictions, - .notes = out_notes - ) -} - -tune_cluster_loop_iter_safely <- function( - split, - grid_info, - workflow, - metrics, - control, - seed -) { - tune_cluster_loop_iter_wrapper <- super_safely(tune_cluster_loop_iter) + par_opt <- list() - time <- proc.time() - result <- tune_cluster_loop_iter_wrapper( - split, - grid_info, + # Create static object using tune's make_static + static <- tune::make_static( workflow, - metrics, - control, - seed + param_info = param_info, + grid = grid, + metrics = metrics, + eval_time = NULL, + split_args = split_args, + control = control, + pkgs = load_pkgs, + strategy = strategy ) - new_time <- proc.time() - - # Update with elapsed time - result$result[[".elapsed"]] <- new_time["elapsed"] - time["elapsed"] - error <- result$error - warnings <- result$warnings - result <- result$result - - # No problems - if (is.null(error) && length(warnings) == 0L) { - return(result) + # Handle parallel_over = "everything" + if (control$parallel_over == "everything") { + candidates <- purrr::map(seq_len(nrow(grid)), \(i) grid[i, ]) + inds <- tidyr::crossing(s = seq_along(candidates), b = seq_along(resamples)) + inds <- vec_list_rowwise(inds) } - # No errors, but we might have warning notes - if (is.null(error)) { - res <- result - notes <- result$.notes - } else { - res <- error - notes <- NULL - } + # Use tune's loop_call directly + cl <- tune:::loop_call(control$parallel_over, strategy, par_opt) + res <- rlang::eval_bare(cl) - problems <- list(res = res, signals = warnings) + # Process results + res <- dplyr::bind_rows(res) - notes <- log_problems(notes, control, split, "internal", problems) + resamples <- dplyr::bind_rows(resamples) + id_cols <- grep("^id", names(resamples), value = TRUE) - # Need an output template - if (!is.null(error)) { - result <- list( - .metrics = NULL, - .extracts = NULL, - .predictions = NULL, - .notes = NULL - ) + if (control$parallel_over == "resamples") { + res <- dplyr::full_join(resamples, res, by = id_cols) + } else { + res <- res |> + dplyr::summarize( + dplyr::across(dplyr::matches("^\\."), ~ list(purrr::list_rbind(.x))), + .by = c(!!!id_cols) + ) |> + dplyr::full_join(resamples, by = id_cols) } - # Update with new notes - result[[".notes"]] <- notes - - result -} - -# https://github.com/tidymodels/tune/blob/main/R/grid_code_paths.R#L542 -super_safely <- function(fn) { - warnings <- list() - handle_error <- function(e) { - e <- structure(e$message, class = "try-error", condition = e) - list(result = NULL, error = e, warnings = warnings) - } - handle_warning <- function(w) { - warnings <<- c(warnings, list(w)) - rlang::cnd_muffle(w) - } - safe_fn <- function(...) { - withCallingHandlers( - expr = tryCatch( - expr = list( - result = fn(...), - error = NULL, - warnings = warnings - ), - error = handle_error - ), - warning = handle_warning + res <- res |> + dplyr::select(-dplyr::any_of(".seeds")) |> + dplyr::select(-dplyr::any_of("outcome_names")) |> + dplyr::relocate( + splits, + dplyr::starts_with("id"), + .metrics, + .notes, + dplyr::any_of(".extracts") ) - } - safe_fn -} -# https://github.com/tidymodels/tune/blob/main/R/grid_helpers.R#L428 -compute_grid_info_model <- function(workflow, grid, parameters_model) { - spec <- extract_spec_parsnip(workflow) - out <- min_grid(spec, grid) - n_fit_models <- nrow(out) - seq_fit_models <- seq_len(n_fit_models) - msgs_preprocessor <- new_msgs_preprocessor(i = 1L, n = 1L) - msgs_preprocessor <- rep(msgs_preprocessor, times = n_fit_models) - msgs_model <- new_msgs_model( - i = seq_fit_models, - n = n_fit_models, - msgs_preprocessor = msgs_preprocessor - ) - iter_configs <- compute_config_ids(out, "Preprocessor1") - out <- tibble::add_column( - .data = out, - .iter_preprocessor = 1L, - .before = 1L - ) - out <- tibble::add_column( - .data = out, - .msg_preprocessor = msgs_preprocessor, - .after = ".iter_preprocessor" - ) - out <- tibble::add_column( - .data = out, - .iter_model = seq_fit_models, - .after = ".msg_preprocessor" - ) - out <- tibble::add_column( - .data = out, - .iter_config = iter_configs, - .after = ".iter_model" - ) - out <- tibble::add_column( - .data = out, - .msg_model = msgs_model, - .after = ".iter_config" - ) - out + res } -# https://github.com/tidymodels/tune/blob/main/R/grid_helpers.R#L484 -compute_grid_info_model_and_preprocessor <- function( - workflow, - grid, - parameters_model -) { - parameter_names_model <- parameters_model[["id"]] - - # Nest model parameters, keep preprocessor parameters outside - out <- tidyr::nest(grid, data = dplyr::all_of(parameter_names_model)) - - n_preprocessors <- nrow(out) - seq_preprocessors <- seq_len(n_preprocessors) - - # preprocessor / - msgs_preprocessor <- new_msgs_preprocessor( - i = seq_preprocessors, - n = n_preprocessors - ) - - out <- tibble::add_column( - .data = out, - .iter_preprocessor = seq_preprocessors, - .before = 1L - ) - - out <- tibble::add_column( - .data = out, - .msg_preprocessor = msgs_preprocessor, - .after = ".iter_preprocessor" - ) - - spec <- extract_spec_parsnip(workflow) - - ids_preprocessor <- format_with_padding(seq_preprocessors) - ids_preprocessor <- paste0("Preprocessor", ids_preprocessor) - - model_grids <- out[["data"]] - - for (i in seq_preprocessors) { - model_grid <- model_grids[[i]] - - model_grid <- min_grid(spec, model_grid) - - n_fit_models <- nrow(model_grid) - seq_fit_models <- seq_len(n_fit_models) - - msg_preprocessor <- msgs_preprocessor[[i]] - id_preprocessor <- ids_preprocessor[[i]] - - # preprocessor /, model / - msgs_model <- new_msgs_model( - i = seq_fit_models, - n = n_fit_models, - msgs_preprocessor = msg_preprocessor - ) - - # Preprocessor_Model - iter_configs <- compute_config_ids(model_grid, id_preprocessor) - - model_grid <- tibble::add_column( - .data = model_grid, - .iter_model = seq_fit_models, - .before = 1L - ) - - model_grid <- tibble::add_column( - .data = model_grid, - .iter_config = iter_configs, - .after = ".iter_model" - ) - - model_grid <- tibble::add_column( - .data = model_grid, - .msg_model = msgs_model, - .after = ".iter_config" - ) - - model_grids[[i]] <- model_grid - } - - out[["data"]] <- model_grids - - # Unnest to match other grid-info generators - out <- tidyr::unnest(out, data) - - out +vec_list_rowwise <- function(x) { + vctrs::vec_split(x, by = 1:nrow(x))$val } -# https://github.com/tidymodels/tune/blob/main/R/grid_helpers.R#L359 -compute_grid_info_preprocessor <- function(workflow, grid, parameters_model) { - out <- grid - - n_preprocessors <- nrow(out) - seq_preprocessors <- seq_len(n_preprocessors) - - # Preprocessor_Model1 - ids <- format_with_padding(seq_preprocessors) - iter_configs <- paste0("Preprocessor", ids, "_Model1") - iter_configs <- as.list(iter_configs) - - # preprocessor / - msgs_preprocessor <- new_msgs_preprocessor( - i = seq_preprocessors, - n = n_preprocessors - ) - - # preprocessor /, model 1/1 - msgs_model <- new_msgs_model( - i = 1L, - n = 1L, - msgs_preprocessor = msgs_preprocessor - ) - - # Manually add .submodels column, which will always have empty lists - submodels <- rep_len(list(list()), n_preprocessors) - - out <- tibble::add_column( - .data = out, - .iter_preprocessor = seq_preprocessors, - .before = 1L - ) - - out <- tibble::add_column( - .data = out, - .msg_preprocessor = msgs_preprocessor, - .after = ".iter_preprocessor" - ) - - # Add at the end - out <- tibble::add_column( - .data = out, - .iter_model = 1L, - .after = NULL - ) - - out <- tibble::add_column( - .data = out, - .iter_config = iter_configs, - .after = ".iter_model" - ) - - out <- tibble::add_column( - .data = out, - .msg_model = msgs_model, - .after = ".iter_config" - ) - - out <- tibble::add_column( - .data = out, - .submodels = submodels, - .after = ".msg_model" - ) - - out -} +# ------------------------------------------------------------------------------ # https://github.com/tidymodels/tune/blob/main/R/checks.R#L338 check_metrics <- function(x, object) { @@ -920,7 +381,7 @@ check_parameters <- function( msg <- paste0(msg, "s: ", paste0("'", unk_names, "'", collapse = ", ")) } - tune_log(list(verbose = TRUE), split = NULL, msg, type = "info") + cli::cli_inform(msg) x <- workflows::.fit_pre(workflow, data)$pre$mold$predictors pset$object <- map(pset$object, dials::finalize, x = x) From 548e84e26b2508fe97aff6690e67ef69ef9e2efa Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 5 Mar 2026 15:17:03 -0800 Subject: [PATCH 2/4] delete unused helpers --- R/append.R | 90 ----------- R/pull.R | 81 ---------- R/tune_helpers.R | 398 ----------------------------------------------- 3 files changed, 569 deletions(-) delete mode 100644 R/append.R delete mode 100644 R/pull.R diff --git a/R/append.R b/R/append.R deleted file mode 100644 index 553c5bd..0000000 --- a/R/append.R +++ /dev/null @@ -1,90 +0,0 @@ -# https://github.com/tidymodels/tune/blob/main/R/pull.R#L136 -append_predictions <- function( - collection, - predictions, - split, - control, - .config = NULL -) { - if (!control$save_pred) { - return(NULL) - } - if (inherits(predictions, "try-error")) { - return(collection) - } - - predictions <- vctrs::vec_cbind(predictions, labels(split)) - - if (!rlang::is_null(.config)) { - by <- setdiff(names(.config), ".config") - - if (length(by) == 0L) { - # Nothing to tune, just bind on config - predictions <- vctrs::vec_cbind(predictions, .config) - } else { - predictions <- dplyr::inner_join(predictions, .config, by = by) - } - } - - dplyr::bind_rows(collection, predictions) -} - -append_metrics <- function( - workflow, - collection, - predictions, - metrics, - param_names, - event_level, - split, - .config = NULL -) { - if (inherits(predictions, "try-error")) { - return(collection) - } - - params <- predictions |> - dplyr::select(dplyr::all_of(param_names)) |> - dplyr::distinct() - - tmp_est <- metrics(workflow, new_data = rsample::analysis(split)) - - tmp_est <- cbind(tmp_est, labels(split)) - - tmp_est <- cbind(params, tmp_est) - if (!rlang::is_null(.config)) { - tmp_est <- cbind(tmp_est, .config) - } - dplyr::bind_rows(collection, tmp_est) -} - -append_extracts <- function( - collection, - workflow, - grid, - split, - ctrl, - .config = NULL -) { - extracts <- - grid |> - dplyr::bind_cols(labels(split)) |> - dplyr::mutate( - .extracts = list( - extract_details(workflow, ctrl$extract) - ) - ) - - if (!rlang::is_null(.config)) { - extracts <- cbind(extracts, .config) - } - - dplyr::bind_rows(collection, extracts) -} - -extract_details <- function(object, extractor) { - if (is.null(extractor)) { - return(list()) - } - try(extractor(object), silent = TRUE) -} diff --git a/R/pull.R b/R/pull.R deleted file mode 100644 index 5c010e9..0000000 --- a/R/pull.R +++ /dev/null @@ -1,81 +0,0 @@ -pull_metrics <- function(resamples, res, control) { - out <- pulley(resamples, res, ".metrics") - out$.metrics <- maybe_repair(out$.metrics) - out -} - -pull_notes <- function(resamples, res, control) { - resamples$.notes <- map(res, `[[`, ".notes") - resamples -} - -pull_extracts <- function(resamples, res, control) { - if (!is.null(control$extract)) { - resamples <- pulley(resamples, res, ".extracts") - } - resamples -} - -pull_predictions <- function(resamples, res, control) { - if (control$save_pred) { - resamples <- pulley(resamples, res, ".predictions") - resamples$.predictions <- maybe_repair(resamples$.predictions) - } - resamples -} - -# ------------------------------------------------------------------------------ - -# Grab the new results, make sure that they align row-wise with the rsample -# object and then bind columns -pulley <- function(resamples, res, col) { - if (all(map_lgl(res, inherits, "simpleError"))) { - res <- - resamples |> - dplyr::mutate(col = map(splits, \(x) NULL)) |> - stats::setNames(c(names(resamples), col)) - return(res) - } - - id_cols <- grep("^id", names(resamples), value = TRUE) - resamples <- dplyr::arrange(resamples, !!!rlang::syms(id_cols)) - pulled_vals <- dplyr::bind_rows(map(res, \(x) x[[col]])) - - if (nrow(pulled_vals) == 0) { - res <- - resamples |> - dplyr::mutate(col = map(splits, \(x) NULL)) |> - stats::setNames(c(names(resamples), col)) - return(res) - } - - pulled_vals <- tidyr::nest(pulled_vals, data = -dplyr::starts_with("id")) - names(pulled_vals)[ncol(pulled_vals)] <- col - - res <- new_bare_tibble(resamples) - res <- dplyr::full_join(res, pulled_vals, by = id_cols) - res <- reup_rs(resamples, res) - res -} - -maybe_repair <- function(x) { - not_null <- !map_lgl(x, is.null) - is_tibb <- map_lgl(x, tibble::is_tibble) - ok <- not_null & is_tibb - if (!any(ok)) { - return(x) - } - - good_val <- which(ok)[1] - template <- x[[good_val]][0, ] - - insert_val <- function(x, y) { - if (is.null(x)) { - x <- y - } - x - } - - x <- map(x, insert_val, y = template) - x -} diff --git a/R/tune_helpers.R b/R/tune_helpers.R index 3a9455c..b687e39 100644 --- a/R/tune_helpers.R +++ b/R/tune_helpers.R @@ -65,87 +65,6 @@ new_tune_results <- function( ) } -# https://github.com/tidymodels/tune/blob/main/R/parallel.R -get_operator <- function(allow = TRUE, object) { - is_par <- foreach::getDoParWorkers() > 1 - pkgs <- required_pkgs(object) - blacklist <- c("keras", "rJava") - if (is_par && allow && any(pkgs %in% blacklist)) { - pkgs <- pkgs[pkgs %in% blacklist] - msg <- paste0("'", pkgs, "'", collapse = ", ") - msg <- paste( - "Some required packages prohibit parallel processing: ", - msg - ) - cli::cli_alert_warning(msg) - allow <- FALSE - } - cond <- allow && is_par - if (cond) { - res <- foreach::`%dopar%` - } else { - res <- foreach::`%do%` - } - res -} - -# https://github.com/tidymodels/tune/blob/main/R/grid_helpers.R -new_grid_info_resamples <- function() { - msgs_preprocessor <- new_msgs_preprocessor(i = 1L, n = 1L) - msgs_model <- new_msgs_model( - i = 1L, - n = 1L, - msgs_preprocessor = msgs_preprocessor - ) - iter_config <- list("Preprocessor1_Model1") - out <- tibble::tibble( - .iter_preprocessor = 1L, - .msg_preprocessor = msgs_preprocessor, - .iter_model = 1L, - .iter_config = iter_config, - .msg_model = msgs_model, - .submodels = list(list()) - ) - out -} - -new_msgs_preprocessor <- function(i, n) { - paste0("preprocessor ", i, "/", n) -} - -new_msgs_model <- function(i, n, msgs_preprocessor) { - paste0(msgs_preprocessor, ", model ", i, "/", n) -} - -# https://github.com/tidymodels/tune/blob/main/R/grid_code_paths.R -parallel_over_finalize <- function(parallel_over, n_resamples) { - if (!is.null(parallel_over)) { - return(parallel_over) - } - if (n_resamples == 1L) { - "everything" - } else { - "resamples" - } -} - -# https://github.com/tidymodels/tune/blob/main/R/grid_code_paths.R -generate_seeds <- function(rng, n) { - out <- vector("list", length = n) - if (!rng) { - return(out) - } - original_algorithms <- RNGkind(kind = "L'Ecuyer-CMRG") - original_rng_algorithm <- original_algorithms[[1]] - on.exit(RNGkind(kind = original_rng_algorithm), add = TRUE) - seed <- .Random.seed - for (i in seq_len(n)) { - out[[i]] <- seed - seed <- parallel::nextRNGStream(seed) - } - out -} - # https://github.com/tidymodels/tune/blob/main/R/min_grid.R #' Determine the minimum set of model fits @@ -168,153 +87,10 @@ blank_submodels <- function(grid) { dplyr::mutate_if(is.factor, as.character) } -# https://github.com/tidymodels/tune/blob/main/R/grid_helpers.R -compute_config_ids <- function(data, id_preprocessor) { - submodels <- tidyr::unnest(data, .submodels, keep_empty = TRUE) - submodels <- dplyr::pull(submodels, .submodels) - model_sizes <- lengths(submodels) + 1L - n_total_models <- sum(model_sizes) - ids <- format_with_padding(seq_len(n_total_models)) - ids <- paste0(id_preprocessor, "_Model", ids) - n_fit_models <- nrow(data) - out <- vector("list", length = n_fit_models) - start <- 1L - for (i in seq_len(n_fit_models)) { - size <- model_sizes[[i]] - stop <- start + size - 1L - out[[i]] <- ids[rlang::seq2(start, stop)] - start <- stop + 1L - } - out -} - -format_with_padding <- function(x) { - gsub(" ", "0", format(x)) -} - -set_workflow_recipe <- function(workflow, recipe) { - workflow$pre$actions$recipe$recipe <- recipe - workflow -} - -catch_and_log <- function(.expr, ..., bad_only = FALSE, notes) { - tune_log(..., type = "info") - tmp <- catcher(.expr) - new_notes <- log_problems(notes, ..., tmp, bad_only = bad_only) - assign("out_notes", new_notes, envir = parent.frame()) - tmp$res -} - -tune_log <- function(control, split = NULL, task, type = "success") { - if (!control$verbose) { - return(invisible(NULL)) - } - if (!is.null(split)) { - labs <- labels(split) - labs <- rev(unlist(labs)) - labs <- paste0(labs, collapse = ", ") - labs <- paste0(labs, ": ") - } else { - labs <- "" - } - task <- gsub("\\{", "", task) - siren(paste0(labs, task), type = type) - NULL -} - -catcher <- function(expr) { - signals <- list() - add_cond <- function(cnd) { - signals <<- append(signals, list(cnd)) - rlang::cnd_muffle(cnd) - } - res <- try(withCallingHandlers(warning = add_cond, expr), silent = TRUE) - list(res = res, signals = signals) -} - -siren <- function(x, type = "info") { - tidyclust_color <- get_tidyclust_colors() - types <- names(tidyclust_color$message) - type <- match.arg(type, types) - msg <- glue::glue(x) - symb <- dplyr::case_when( - type == "warning" ~ tidyclust_color$symbol$warning("!"), - type == "go" ~ tidyclust_color$symbol$go(cli::symbol$pointer), - type == "danger" ~ tidyclust_color$symbol$danger("x"), - type == "success" ~ - tidyclust_color$symbol$success(tidyclust_symbol$success), - type == "info" ~ tidyclust_color$symbol$info("i") - ) - msg <- dplyr::case_when( - type == "warning" ~ tidyclust_color$message$warning(msg), - type == "go" ~ tidyclust_color$message$go(msg), - type == "danger" ~ tidyclust_color$message$danger(msg), - type == "success" ~ tidyclust_color$message$success(msg), - type == "info" ~ tidyclust_color$message$info(msg) - ) - if (inherits(msg, "character")) { - msg <- as.character(msg) - } - message(paste(symb, msg)) -} - -log_problems <- function(notes, control, split, loc, res, bad_only = FALSE) { - control2 <- control - control2$verbose <- TRUE - wrn <- res$signals - if (length(wrn) > 0) { - wrn_msg <- map_chr(wrn, \(x) x$message) - wrn_msg <- unique(wrn_msg) - wrn_msg <- paste(wrn_msg, collapse = ", ") - wrn_msg <- tibble::tibble( - location = loc, - type = "warning", - note = wrn_msg - ) - notes <- dplyr::bind_rows(notes, wrn_msg) - wrn_msg <- glue::glue_collapse( - paste0(loc, ": ", wrn_msg$note), - width = options()$width - 5 - ) - tune_log(control2, split, wrn_msg, type = "warning") - } - if (inherits(res$res, "try-error")) { - err_msg <- as.character(attr(res$res, "condition")) - err_msg <- gsub("\n$", "", err_msg) - err_msg <- tibble::tibble( - location = loc, - type = "error", - note = err_msg - ) - notes <- dplyr::bind_rows(notes, err_msg) - err_msg <- glue::glue_collapse( - paste0(loc, ": ", err_msg$note), - width = options()$width - 5 - ) - tune_log(control2, split, err_msg, type = "danger") - } else { - if (!bad_only) { - tune_log(control, split, loc, type = "success") - } - } - notes -} - is_failure <- function(x) { inherits(x, "try-error") } -# https://github.com/tidymodels/tune/blob/main/R/grid_helpers.R -finalize_workflow_spec <- function(workflow, grid_model) { - if (ncol(grid_model) == 0L) { - return(workflow) - } - spec <- extract_spec_parsnip(workflow) - spec <- merge(spec, grid_model)$x[[1]] - workflow <- set_workflow_spec(workflow, spec) - workflow -} - #' @export merge.cluster_spec <- function(x, y, ...) { merger(x, y, ...) @@ -386,127 +162,6 @@ update_recipe <- function(grid, object, pset, step_id, nms, ...) { object } -catch_and_log_fit <- function(expr, ..., notes) { - tune_log(..., type = "info") - caught <- catcher(expr) - result <- caught$res - if (is_failure(result)) { - result_parsnip <- list(res = result, signals = list()) - new_notes <- log_problems(notes, ..., result_parsnip) - assign("out_notes", new_notes, envir = parent.frame()) - return(result) - } - if (!is_workflow(result)) { - cli::cli_abort("Internal error: Model result is not a workflow!") - } - fit <- result$fit$fit$fit - if (is_failure(fit)) { - result_fit <- list(res = fit, signals = list()) - new_notes <- log_problems(notes, ..., result_fit) - assign("out_notes", new_notes, envir = parent.frame()) - return(result) - } - new_notes <- log_problems(notes, ..., caught) - assign("out_notes", new_notes, envir = parent.frame()) - result -} - -# https://github.com/tidymodels/tune/blob/main/R/grid_helpers.R -predict_model <- function(split, workflow, grid, metrics, submodels = NULL) { - model <- extract_fit_parsnip(workflow) - - forged <- forge_from_workflow(split, workflow) - - x_vals <- forged$predictors - - orig_rows <- as.integer(split, data = "assessment") - - if (length(orig_rows) != nrow(x_vals)) { - msg <- paste0( - "Some assessment set rows are not available at ", - "prediction time. " - ) - - if (has_preprocessor_recipe(workflow)) { - msg <- paste0( - msg, - "Consider using `skip = TRUE` on any recipe steps that remove rows ", - "to avoid calling them on the assessment set." - ) - } else { - msg <- paste0( - msg, - "Did your preprocessing steps filter or remove rows?" - ) - } - - cli::cli_abort(msg) - } - - # Determine the type of prediction that is required - types <- "cluster" - - res <- NULL - merge_vars <- c(".row", names(grid)) - - for (type_iter in types) { - # Regular predictions - tmp_res <- - stats::predict(model, x_vals, type = type_iter) |> - dplyr::mutate(.row = orig_rows) |> - cbind(grid, row.names = NULL) - - if (!is.null(submodels)) { - submod_length <- lengths(submodels) - has_submodels <- any(submod_length > 0) - - # if (has_submodels) { - # submod_param <- names(submodels) - # mp_call <- - # call2( - # "multi_predict", - # .ns = "parsnip", - # object = expr(model), - # new_data = expr(x_vals), - # type = type_iter, - # !!!make_submod_arg(grid, model, submodels) - # ) - # tmp_res <- - # eval_tidy(mp_call) |> - # mutate(.row = orig_rows) |> - # unnest(cols = dplyr::starts_with(".pred")) |> - # cbind(dplyr::select(grid, -dplyr::all_of(submod_param)), - # row.names = NULL) |> - # # go back to user-defined name - # dplyr::rename(!!!make_rename_arg(grid, model, submodels)) |> - # dplyr::select(dplyr::one_of(names(tmp_res))) |> - # dplyr::bind_rows(tmp_res) - # } - } - - if (!is.null(res)) { - res <- dplyr::full_join(res, tmp_res, by = merge_vars) - } else { - res <- tmp_res - } - - rm(tmp_res) - } # end type loop - - tibble::as_tibble(res) -} - -forge_from_workflow <- function(split, workflow) { - new_data <- rsample::assessment(split) - - blueprint <- workflow$pre$mold$blueprint - - # Can't use tune version since outcomes = FALSE - forged <- hardhat::forge(new_data, blueprint, outcomes = FALSE) - - forged -} - # ------------------------------------------------------------------------------ # https://github.com/tidymodels/tune/blob/main/R/grid_helpers.R#L613 @@ -542,59 +197,6 @@ set_workflow_recipe <- function(workflow, recipe) { workflow } -# ------------------------------------------------------------------------------ - -# https://github.com/tidymodels/tune/blob/main/R/pull.R#L210 -extract_metrics_config <- function(param_names, metrics) { - metrics_config_names <- c(param_names, ".config") - out <- metrics[metrics_config_names] - vctrs::vec_unique(out) -} - -# https://github.com/tidymodels/tune/blob/main/R/tune_bayes.R#L784 -# Make sure that rset object attributes are kept once joined -reup_rs <- function(resamples, res) { - sort_cols <- grep("^id", names(resamples), value = TRUE) - if (any(names(res) == ".iter")) { - sort_cols <- c(".iter", sort_cols) - } - res <- dplyr::arrange(res, !!!rlang::syms(sort_cols)) - att <- attributes(res) - rsample_att <- attributes(resamples) - for (i in names(rsample_att)) { - if (!any(names(att) == i)) { - attr(res, i) <- rsample_att[[i]] - } - } - - class(res) <- unique(c("tune_results", class(res))) - res -} - is_workflow <- function(x) { inherits(x, "workflow") } - -grid_msg <- "`grid` should be a positive integer or a data frame." - -slice_seeds <- function(x, i, n) { - x[(i - 1L) * n + seq_len(n)] -} - -iter_combine <- function(...) { - results <- list(...) - metrics <- map(results, \(.x) .x[[".metrics"]]) - extracts <- map(results, \(.x) .x[[".extracts"]]) - predictions <- map(results, \(.x) .x[[".predictions"]]) - notes <- map(results, \(.x) .x[[".notes"]]) - metrics <- vctrs::vec_c(!!!metrics) - extracts <- vctrs::vec_c(!!!extracts) - predictions <- vctrs::vec_c(!!!predictions) - notes <- vctrs::vec_c(!!!notes) - list( - .metrics = metrics, - .extracts = extracts, - .predictions = predictions, - .notes = notes - ) -} From 412503e94c5e6d6daa780aa2bceb5a626695fc34 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 5 Mar 2026 15:17:25 -0800 Subject: [PATCH 3/4] various small things --- DESCRIPTION | 4 +- R/aaa.R | 7 +- tests/testthat/_snaps/tune_cluster.md | 93 +++++++++++++-------------- tests/testthat/test-fitting.R | 1 - tests/testthat/test-tune_cluster.R | 12 ++-- 5 files changed, 54 insertions(+), 63 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 2a4a273..039a337 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -22,7 +22,6 @@ Imports: dials (>= 1.3.0), dplyr (>= 1.0.9), flexclust (>= 1.3-6), - foreach, generics (>= 0.1.2), glue (>= 1.6.2), hardhat (>= 1.0.0), @@ -30,6 +29,7 @@ Imports: parsnip (>= 1.0.2), philentropy (>= 0.9.0), prettyunits (>= 1.1.0), + purrr (>= 1.0.0), rlang (>= 1.0.6), rsample (>= 1.0.0), stats, @@ -57,4 +57,4 @@ Config/testthat/edition: 3 Config/usethis/last-upkeep: 2025-04-24 Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.2 +RoxygenNote: 7.3.3 diff --git a/R/aaa.R b/R/aaa.R index db45641..cdec94f 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -6,11 +6,8 @@ utils::globalVariables( ".", "..object", ".cluster", - ".iter_config", - ".iter_model", - ".iter_preprocessor", - ".msg_model", - ".submodels", + ".metrics", + ".notes", "call_info", "cluster", "component", diff --git a/tests/testthat/_snaps/tune_cluster.md b/tests/testthat/_snaps/tune_cluster.md index e830888..617dd99 100644 --- a/tests/testthat/_snaps/tune_cluster.md +++ b/tests/testthat/_snaps/tune_cluster.md @@ -5,71 +5,71 @@ metrics = metrics) Message i Fold1: preprocessor 1/3 - v Fold1: preprocessor 1/3 + i Fold1: preprocessor 1/3 (prediction data) i Fold1: preprocessor 1/3, model 1/3 - v Fold1: preprocessor 1/3, model 1/3 - i Fold1: preprocessor 1/3, model 1/3 (predictions) + i Fold1: preprocessor 1/3, model 1/3 (model metrics) + i Fold1: preprocessor 1/3, model 1/3 (extracts) i Fold1: preprocessor 1/3, model 2/3 - v Fold1: preprocessor 1/3, model 2/3 - i Fold1: preprocessor 1/3, model 2/3 (predictions) + i Fold1: preprocessor 1/3, model 2/3 (model metrics) + i Fold1: preprocessor 1/3, model 2/3 (extracts) i Fold1: preprocessor 1/3, model 3/3 - v Fold1: preprocessor 1/3, model 3/3 - i Fold1: preprocessor 1/3, model 3/3 (predictions) + i Fold1: preprocessor 1/3, model 3/3 (model metrics) + i Fold1: preprocessor 1/3, model 3/3 (extracts) i Fold1: preprocessor 2/3 - v Fold1: preprocessor 2/3 + i Fold1: preprocessor 2/3 (prediction data) i Fold1: preprocessor 2/3, model 1/3 - v Fold1: preprocessor 2/3, model 1/3 - i Fold1: preprocessor 2/3, model 1/3 (predictions) + i Fold1: preprocessor 2/3, model 1/3 (model metrics) + i Fold1: preprocessor 2/3, model 1/3 (extracts) i Fold1: preprocessor 2/3, model 2/3 - v Fold1: preprocessor 2/3, model 2/3 - i Fold1: preprocessor 2/3, model 2/3 (predictions) + i Fold1: preprocessor 2/3, model 2/3 (model metrics) + i Fold1: preprocessor 2/3, model 2/3 (extracts) i Fold1: preprocessor 2/3, model 3/3 - v Fold1: preprocessor 2/3, model 3/3 - i Fold1: preprocessor 2/3, model 3/3 (predictions) + i Fold1: preprocessor 2/3, model 3/3 (model metrics) + i Fold1: preprocessor 2/3, model 3/3 (extracts) i Fold1: preprocessor 3/3 - v Fold1: preprocessor 3/3 + i Fold1: preprocessor 3/3 (prediction data) i Fold1: preprocessor 3/3, model 1/3 - v Fold1: preprocessor 3/3, model 1/3 - i Fold1: preprocessor 3/3, model 1/3 (predictions) + i Fold1: preprocessor 3/3, model 1/3 (model metrics) + i Fold1: preprocessor 3/3, model 1/3 (extracts) i Fold1: preprocessor 3/3, model 2/3 - v Fold1: preprocessor 3/3, model 2/3 - i Fold1: preprocessor 3/3, model 2/3 (predictions) + i Fold1: preprocessor 3/3, model 2/3 (model metrics) + i Fold1: preprocessor 3/3, model 2/3 (extracts) i Fold1: preprocessor 3/3, model 3/3 - v Fold1: preprocessor 3/3, model 3/3 - i Fold1: preprocessor 3/3, model 3/3 (predictions) + i Fold1: preprocessor 3/3, model 3/3 (model metrics) + i Fold1: preprocessor 3/3, model 3/3 (extracts) i Fold2: preprocessor 1/3 - v Fold2: preprocessor 1/3 + i Fold2: preprocessor 1/3 (prediction data) i Fold2: preprocessor 1/3, model 1/3 - v Fold2: preprocessor 1/3, model 1/3 - i Fold2: preprocessor 1/3, model 1/3 (predictions) + i Fold2: preprocessor 1/3, model 1/3 (model metrics) + i Fold2: preprocessor 1/3, model 1/3 (extracts) i Fold2: preprocessor 1/3, model 2/3 - v Fold2: preprocessor 1/3, model 2/3 - i Fold2: preprocessor 1/3, model 2/3 (predictions) + i Fold2: preprocessor 1/3, model 2/3 (model metrics) + i Fold2: preprocessor 1/3, model 2/3 (extracts) i Fold2: preprocessor 1/3, model 3/3 - v Fold2: preprocessor 1/3, model 3/3 - i Fold2: preprocessor 1/3, model 3/3 (predictions) + i Fold2: preprocessor 1/3, model 3/3 (model metrics) + i Fold2: preprocessor 1/3, model 3/3 (extracts) i Fold2: preprocessor 2/3 - v Fold2: preprocessor 2/3 + i Fold2: preprocessor 2/3 (prediction data) i Fold2: preprocessor 2/3, model 1/3 - v Fold2: preprocessor 2/3, model 1/3 - i Fold2: preprocessor 2/3, model 1/3 (predictions) + i Fold2: preprocessor 2/3, model 1/3 (model metrics) + i Fold2: preprocessor 2/3, model 1/3 (extracts) i Fold2: preprocessor 2/3, model 2/3 - v Fold2: preprocessor 2/3, model 2/3 - i Fold2: preprocessor 2/3, model 2/3 (predictions) + i Fold2: preprocessor 2/3, model 2/3 (model metrics) + i Fold2: preprocessor 2/3, model 2/3 (extracts) i Fold2: preprocessor 2/3, model 3/3 - v Fold2: preprocessor 2/3, model 3/3 - i Fold2: preprocessor 2/3, model 3/3 (predictions) + i Fold2: preprocessor 2/3, model 3/3 (model metrics) + i Fold2: preprocessor 2/3, model 3/3 (extracts) i Fold2: preprocessor 3/3 - v Fold2: preprocessor 3/3 + i Fold2: preprocessor 3/3 (prediction data) i Fold2: preprocessor 3/3, model 1/3 - v Fold2: preprocessor 3/3, model 1/3 - i Fold2: preprocessor 3/3, model 1/3 (predictions) + i Fold2: preprocessor 3/3, model 1/3 (model metrics) + i Fold2: preprocessor 3/3, model 1/3 (extracts) i Fold2: preprocessor 3/3, model 2/3 - v Fold2: preprocessor 3/3, model 2/3 - i Fold2: preprocessor 3/3, model 2/3 (predictions) + i Fold2: preprocessor 3/3, model 2/3 (model metrics) + i Fold2: preprocessor 3/3, model 2/3 (extracts) i Fold2: preprocessor 3/3, model 3/3 - v Fold2: preprocessor 3/3, model 3/3 - i Fold2: preprocessor 3/3, model 3/3 (predictions) + i Fold2: preprocessor 3/3, model 3/3 (model metrics) + i Fold2: preprocessor 3/3, model 3/3 (extracts) # tune model only - failure in formula is caught elegantly @@ -79,10 +79,7 @@ 1 }, save_pred = TRUE)) Message - x Fold1: preprocessor 1/1: Error in `hardhat::mold()`: - ! The following predictor ... - x Fold2: preprocessor 1/1: Error in `hardhat::mold()`: - ! The following predictor ... + > A | error: The following predictor was not found in `data`: "z". Condition Warning: All models failed. @@ -119,8 +116,8 @@ # A tibble: 2 x 4 splits id .metrics .notes - 1 Fold1 - 2 Fold2 + 1 Fold1 + 2 Fold2 # select_best() and show_best() works diff --git a/tests/testthat/test-fitting.R b/tests/testthat/test-fitting.R index d3d62b5..3bd94cd 100644 --- a/tests/testthat/test-fitting.R +++ b/tests/testthat/test-fitting.R @@ -119,4 +119,3 @@ test_that("fit_xy() works with matrix input", { expect_s3_class(fit, "cluster_fit") }) - diff --git a/tests/testthat/test-tune_cluster.R b/tests/testthat/test-tune_cluster.R index 99a5e4b..dab440e 100644 --- a/tests/testthat/test-tune_cluster.R +++ b/tests/testthat/test-tune_cluster.R @@ -30,7 +30,7 @@ test_that("tune recipe only", { expect_equal(sum(res_est$.metric == "sse_total"), nrow(grid)) expect_equal(sum(res_est$.metric == "sse_within_total"), nrow(grid)) expect_equal(res_est$n, rep(2, nrow(grid) * 2)) - expect_false(identical(num_comp, expr(tune()))) + expect_false(identical(num_comp, rlang::expr(tune()))) expect_true(res_workflow$trained) }) @@ -67,7 +67,7 @@ test_that("tune model only (with recipe)", { expect_equal(sum(res_est$.metric == "sse_total"), nrow(grid)) expect_equal(sum(res_est$.metric == "sse_within_total"), nrow(grid)) expect_equal(res_est$n, rep(2, nrow(grid) * 2)) - expect_false(identical(num_clusters, expr(tune()))) + expect_false(identical(num_clusters, rlang::expr(tune()))) expect_true(res_workflow$trained) }) @@ -175,8 +175,8 @@ test_that("tune model and recipe", { expect_equal(sum(res_est$.metric == "sse_total"), nrow(grid)) expect_equal(sum(res_est$.metric == "sse_within_total"), nrow(grid)) expect_equal(res_est$n, rep(2, nrow(grid) * 2)) - expect_false(identical(num_clusters, expr(tune()))) - expect_false(identical(num_comp, expr(tune()))) + expect_false(identical(num_clusters, rlang::expr(tune()))) + expect_false(identical(num_comp, rlang::expr(tune()))) expect_true(res_workflow$trained) }) @@ -279,13 +279,11 @@ test_that("tune model only - failure in formula is caught elegantly", { note <- notes[[1]]$note extracts <- cars_res$.extracts - predictions <- cars_res$.predictions expect_length(notes, 2L) # formula failed - no models run expect_equal(extracts, list(NULL, NULL)) - expect_equal(predictions, list(NULL, NULL)) }) test_that("argument order gives errors for recipes", { @@ -377,7 +375,7 @@ test_that("tune recipe only", { expect_equal(nrow(res_est), nrow(grid)) expect_equal(sum(res_est$.metric == "sse_within_total"), nrow(grid)) expect_equal(res_est$n, rep(2, nrow(grid))) - expect_false(identical(num_comp, expr(tune()))) + expect_false(identical(num_comp, rlang::expr(tune()))) expect_true(res_workflow$trained) }) From e6d5c2d544dc20e36ceb4e54706246f1a17e9310 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 5 Mar 2026 15:22:04 -0800 Subject: [PATCH 4/4] update description --- DESCRIPTION | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 039a337..b3f2ead 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -35,7 +35,7 @@ Imports: stats, tibble (>= 3.1.0), tidyr (>= 1.2.0), - tune (>= 1.0.0), + tune (>= 2.0.1.9004), utils, vctrs (>= 0.5.0) Suggests: @@ -51,6 +51,8 @@ Suggests: rmarkdown, testthat (>= 3.0.0), workflows (>= 1.1.2) +Remotes: + tidymodels/tune#1141 Config/Needs/website: pkgdown, tidymodels, tidyverse, palmerpenguins, patchwork, ggforce, tidyverse/tidytemplate Config/testthat/edition: 3