From f30d775c5f253c6616f2dbc724145b5f61e9d673 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Fri, 20 Mar 2026 13:19:45 +0000 Subject: [PATCH 1/3] add test on "all rows of the data set are available" --- ...ver-all-stages-helpers-predict-all-types.R | 74 +++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/tests/testthat/test-loop-over-all-stages-helpers-predict-all-types.R b/tests/testthat/test-loop-over-all-stages-helpers-predict-all-types.R index 5a159e2f..57867fe4 100644 --- a/tests/testthat/test-loop-over-all-stages-helpers-predict-all-types.R +++ b/tests/testthat/test-loop-over-all-stages-helpers-predict-all-types.R @@ -977,3 +977,77 @@ test_that("predict censored regression - submodels - no calibration", { ) expect_equal(nrow(res_dyn), nrow(assessment(cens$rs$splits[[1]]))) }) + +# ------------------------------------------------------------------------------ +# Row availability checks + +test_that("process_prediction_data errors when recipe drops rows", { + skip_if_not_installed("modeldata") + + cls <- make_post_data() + + filter_rec <- recipe(class ~ ., data = cls$data) |> + step_filter(non_linear_1 > 0, skip = FALSE) + + wflow <- workflow(filter_rec, logistic_reg()) + wflow_fit <- fit(wflow, cls$data) + + ctrl <- tune::control_grid() + data_1 <- tune:::.get_data_subsets(wflow, cls$rs$splits[[1]], cls$args) + + static <- tune:::make_static( + wflow, + param_info = wflow |> extract_parameter_set_dials(), + grid = tibble(), + metrics = metric_set(accuracy), + eval_time = NULL, + split_args = cls$args, + control = ctrl + ) + static <- tune:::update_static(static, data_1) + + expect_snapshot( + tune:::process_prediction_data(wflow_fit, static), + error = TRUE + ) +}) + +test_that("process_prediction_data errors when non-recipe preprocessor drops rows", { + skip_if_not_installed("modeldata") + + cls <- make_post_data() + + wflow <- workflow(class ~ ., logistic_reg()) + wflow_fit <- fit(wflow, cls$data) + + ctrl <- tune::control_grid() + data_1 <- tune:::.get_data_subsets(wflow, cls$rs$splits[[1]], cls$args) + + static <- tune:::make_static( + wflow, + param_info = wflow |> extract_parameter_set_dials(), + grid = tibble(), + metrics = metric_set(accuracy), + eval_time = NULL, + split_args = cls$args, + control = ctrl + ) + static <- tune:::update_static(static, data_1) + + local_mocked_bindings( + forge_from_workflow = function(new_data, workflow) { + res <- hardhat::forge( + new_data, + workflow$pre$mold$blueprint, + outcomes = TRUE + ) + res$predictors <- res$predictors[1:5, ] + res + } + ) + + expect_snapshot( + tune:::process_prediction_data(wflow_fit, static), + error = TRUE + ) +}) From 1e1c6bfa0dbaac2f0b2fc959db024688024e113d Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Fri, 20 Mar 2026 13:35:39 +0000 Subject: [PATCH 2/3] bring back check on "all rows are available" --- R/grid_helpers.R | 26 ------------------- R/loop_over_all_stages-helpers.R | 18 +++++++++++-- ...er-all-stages-helpers-predict-all-types.md | 18 +++++++++++++ 3 files changed, 34 insertions(+), 28 deletions(-) create mode 100644 tests/testthat/_snaps/loop-over-all-stages-helpers-predict-all-types.md diff --git a/R/grid_helpers.R b/R/grid_helpers.R index c64747fa..7d3c26ef 100644 --- a/R/grid_helpers.R +++ b/R/grid_helpers.R @@ -1,29 +1,3 @@ -# TODO old predict_model bits to check into -# # TODO patch since parsnip does not record the column names when Surv objects -# # are used with fit_xy() -# if (model$spec$mode == "censored regression") { -# model$preproc$y_var <- names(y_vals) -# } -# -# if (length(orig_rows) != nrow(x_vals)) { -# msg <- "Some assessment set rows are not available at prediction time." -# -# if (has_preprocessor_recipe(workflow)) { -# msg <- -# c( -# msg, -# i = -# "Consider using {.code skip = TRUE} on any recipe steps that -# remove rows to avoid calling them on the assessment set." -# -# ) -# } else { -# msg <- c(msg, i = "Did your preprocessing steps filter or remove rows?") -# } -# -# cli::cli_abort(msg) -# } - trim_ipcw <- function(x) { x$.weight_time <- NULL x$.pred_censored <- NULL diff --git a/R/loop_over_all_stages-helpers.R b/R/loop_over_all_stages-helpers.R index bc8be390..d0077149 100644 --- a/R/loop_over_all_stages-helpers.R +++ b/R/loop_over_all_stages-helpers.R @@ -225,8 +225,6 @@ predict_all_types <- function( model_fit <- wflow_fit |> hardhat::extract_fit_parsnip() - # TODO tune::predict_model has some pre-prediction checks - sub_param <- names(submodel_grid) # Convert argument names to parsnip format see #1011 @@ -345,6 +343,22 @@ process_prediction_data <- function( .ind <- static$data[[source]]$ind processed_data <- forge_from_workflow(.data, wflow_fit) + + if (nrow(processed_data$predictors) != nrow(.data)) { + set_name <- if (source == "pred") "assessment" else "calibration" + msg <- "Some {set_name} set rows are not available at prediction time." + if (has_preprocessor_recipe(wflow_fit)) { + msg <- c( + msg, + i = "Consider using {.code skip = TRUE} on any recipe steps that + remove rows to avoid calling them on the {set_name} set." + ) + } else { + msg <- c(msg, i = "Did your preprocessing steps filter or remove rows?") + } + cli::cli_abort(msg) + } + processed_data$outcomes <- processed_data$outcomes |> dplyr::mutate(.row = .ind) processed_data diff --git a/tests/testthat/_snaps/loop-over-all-stages-helpers-predict-all-types.md b/tests/testthat/_snaps/loop-over-all-stages-helpers-predict-all-types.md new file mode 100644 index 00000000..830e3e70 --- /dev/null +++ b/tests/testthat/_snaps/loop-over-all-stages-helpers-predict-all-types.md @@ -0,0 +1,18 @@ +# process_prediction_data errors when recipe drops rows + + Code + tune:::process_prediction_data(wflow_fit, static) + Condition + Error in `tune:::process_prediction_data()`: + ! Some assessment set rows are not available at prediction time. + i Consider using `skip = TRUE` on any recipe steps that remove rows to avoid calling them on the assessment set. + +# process_prediction_data errors when non-recipe preprocessor drops rows + + Code + tune:::process_prediction_data(wflow_fit, static) + Condition + Error in `tune:::process_prediction_data()`: + ! Some assessment set rows are not available at prediction time. + i Did your preprocessing steps filter or remove rows? + From 5d6f97573d7de6298a58873f3a707d7d35fc3ab9 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Fri, 20 Mar 2026 14:24:39 +0000 Subject: [PATCH 3/3] add test message for calibration set --- ...er-all-stages-helpers-predict-all-types.md | 9 +++++ ...ver-all-stages-helpers-predict-all-types.R | 38 +++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/tests/testthat/_snaps/loop-over-all-stages-helpers-predict-all-types.md b/tests/testthat/_snaps/loop-over-all-stages-helpers-predict-all-types.md index 830e3e70..fc62fc50 100644 --- a/tests/testthat/_snaps/loop-over-all-stages-helpers-predict-all-types.md +++ b/tests/testthat/_snaps/loop-over-all-stages-helpers-predict-all-types.md @@ -7,6 +7,15 @@ ! Some assessment set rows are not available at prediction time. i Consider using `skip = TRUE` on any recipe steps that remove rows to avoid calling them on the assessment set. +# process_prediction_data errors when recipe drops calibration rows + + Code + tune:::process_prediction_data(wflow_fit, static, source = "cal") + Condition + Error in `tune:::process_prediction_data()`: + ! Some calibration set rows are not available at prediction time. + i Consider using `skip = TRUE` on any recipe steps that remove rows to avoid calling them on the calibration set. + # process_prediction_data errors when non-recipe preprocessor drops rows Code diff --git a/tests/testthat/test-loop-over-all-stages-helpers-predict-all-types.R b/tests/testthat/test-loop-over-all-stages-helpers-predict-all-types.R index 57867fe4..88239419 100644 --- a/tests/testthat/test-loop-over-all-stages-helpers-predict-all-types.R +++ b/tests/testthat/test-loop-over-all-stages-helpers-predict-all-types.R @@ -1012,6 +1012,44 @@ test_that("process_prediction_data errors when recipe drops rows", { ) }) +test_that("process_prediction_data errors when recipe drops calibration rows", { + skip_if_not_installed("modeldata") + skip_if_not_installed("probably") + + cls <- make_post_data() + + filter_rec <- recipe(class ~ ., data = cls$data) |> + step_filter(non_linear_1 > 0, skip = FALSE) + + wflow <- workflow(filter_rec, logistic_reg(), cls_est_post) + # TODO just use cls$data + # after https://github.com/tidymodels/workflows/issues/315 is fixed + wflow_fit <- fit( + wflow, + data = cls$data, + data_calibration = cls$data |> filter(non_linear_1 > 0) + ) + + ctrl <- tune::control_grid() + data_1 <- tune:::.get_data_subsets(wflow, cls$rs$splits[[1]], cls$args) + + static <- tune:::make_static( + wflow, + param_info = wflow |> extract_parameter_set_dials(), + grid = tibble(), + metrics = metric_set(accuracy), + eval_time = NULL, + split_args = cls$args, + control = ctrl + ) + static <- tune:::update_static(static, data_1) + + expect_snapshot( + tune:::process_prediction_data(wflow_fit, static, source = "cal"), + error = TRUE + ) +}) + test_that("process_prediction_data errors when non-recipe preprocessor drops rows", { skip_if_not_installed("modeldata")