Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions R/grid_helpers.R
Original file line number Diff line number Diff line change
@@ -1,29 +1,3 @@
# TODO old predict_model bits to check into
# # TODO patch since parsnip does not record the column names when Surv objects
# # are used with fit_xy()
# if (model$spec$mode == "censored regression") {
# model$preproc$y_var <- names(y_vals)
# }
#
# if (length(orig_rows) != nrow(x_vals)) {
# msg <- "Some assessment set rows are not available at prediction time."
#
# if (has_preprocessor_recipe(workflow)) {
# msg <-
# c(
# msg,
# i =
# "Consider using {.code skip = TRUE} on any recipe steps that
# remove rows to avoid calling them on the assessment set."
#
# )
# } else {
# msg <- c(msg, i = "Did your preprocessing steps filter or remove rows?")
# }
#
# cli::cli_abort(msg)
# }

trim_ipcw <- function(x) {
x$.weight_time <- NULL
x$.pred_censored <- NULL
Expand Down
18 changes: 16 additions & 2 deletions R/loop_over_all_stages-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,6 @@ predict_all_types <- function(

model_fit <- wflow_fit |> hardhat::extract_fit_parsnip()

# TODO tune::predict_model has some pre-prediction checks

sub_param <- names(submodel_grid)

# Convert argument names to parsnip format see #1011
Expand Down Expand Up @@ -345,6 +343,22 @@ process_prediction_data <- function(
.ind <- static$data[[source]]$ind

processed_data <- forge_from_workflow(.data, wflow_fit)

if (nrow(processed_data$predictors) != nrow(.data)) {
set_name <- if (source == "pred") "assessment" else "calibration"
msg <- "Some {set_name} set rows are not available at prediction time."
if (has_preprocessor_recipe(wflow_fit)) {
msg <- c(
msg,
i = "Consider using {.code skip = TRUE} on any recipe steps that
remove rows to avoid calling them on the {set_name} set."
)
} else {
msg <- c(msg, i = "Did your preprocessing steps filter or remove rows?")
}
cli::cli_abort(msg)
}

processed_data$outcomes <- processed_data$outcomes |>
dplyr::mutate(.row = .ind)
processed_data
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# process_prediction_data errors when recipe drops rows

Code
tune:::process_prediction_data(wflow_fit, static)
Condition
Error in `tune:::process_prediction_data()`:
! Some assessment set rows are not available at prediction time.
i Consider using `skip = TRUE` on any recipe steps that remove rows to avoid calling them on the assessment set.

# process_prediction_data errors when recipe drops calibration rows

Code
tune:::process_prediction_data(wflow_fit, static, source = "cal")
Condition
Error in `tune:::process_prediction_data()`:
! Some calibration set rows are not available at prediction time.
i Consider using `skip = TRUE` on any recipe steps that remove rows to avoid calling them on the calibration set.

# process_prediction_data errors when non-recipe preprocessor drops rows

Code
tune:::process_prediction_data(wflow_fit, static)
Condition
Error in `tune:::process_prediction_data()`:
! Some assessment set rows are not available at prediction time.
i Did your preprocessing steps filter or remove rows?

112 changes: 112 additions & 0 deletions tests/testthat/test-loop-over-all-stages-helpers-predict-all-types.R
Original file line number Diff line number Diff line change
Expand Up @@ -977,3 +977,115 @@ test_that("predict censored regression - submodels - no calibration", {
)
expect_equal(nrow(res_dyn), nrow(assessment(cens$rs$splits[[1]])))
})

# ------------------------------------------------------------------------------
# Row availability checks

test_that("process_prediction_data errors when recipe drops rows", {
  skip_if_not_installed("modeldata")

  post_data <- make_post_data()

  # A recipe with `skip = FALSE` filters rows at bake time, so the
  # assessment set loses rows and prediction should fail loudly.
  rec <- recipe(class ~ ., data = post_data$data) |>
    step_filter(non_linear_1 > 0, skip = FALSE)

  wflow <- workflow(rec, logistic_reg())
  wflow_fit <- fit(wflow, post_data$data)

  split_data <- tune:::.get_data_subsets(
    wflow,
    post_data$rs$splits[[1]],
    post_data$args
  )

  static <- tune:::make_static(
    wflow,
    param_info = extract_parameter_set_dials(wflow),
    grid = tibble(),
    metrics = metric_set(accuracy),
    eval_time = NULL,
    split_args = post_data$args,
    control = tune::control_grid()
  )
  static <- tune:::update_static(static, split_data)

  expect_snapshot(
    tune:::process_prediction_data(wflow_fit, static),
    error = TRUE
  )
})

test_that("process_prediction_data errors when recipe drops calibration rows", {
  skip_if_not_installed("modeldata")
  skip_if_not_installed("probably")

  post_data <- make_post_data()

  # A non-skipped filter step drops rows when the recipe is baked, so
  # the calibration set shrinks and prediction should error.
  rec <- recipe(class ~ ., data = post_data$data) |>
    step_filter(non_linear_1 > 0, skip = FALSE)

  wflow <- workflow(rec, logistic_reg(), cls_est_post)
  # TODO just use post_data$data
  # after https://github.com/tidymodels/workflows/issues/315 is fixed
  wflow_fit <- fit(
    wflow,
    data = post_data$data,
    data_calibration = post_data$data |> filter(non_linear_1 > 0)
  )

  split_data <- tune:::.get_data_subsets(
    wflow,
    post_data$rs$splits[[1]],
    post_data$args
  )

  static <- tune:::make_static(
    wflow,
    param_info = extract_parameter_set_dials(wflow),
    grid = tibble(),
    metrics = metric_set(accuracy),
    eval_time = NULL,
    split_args = post_data$args,
    control = tune::control_grid()
  )
  static <- tune:::update_static(static, split_data)

  expect_snapshot(
    tune:::process_prediction_data(wflow_fit, static, source = "cal"),
    error = TRUE
  )
})

test_that("process_prediction_data errors when non-recipe preprocessor drops rows", {
  skip_if_not_installed("modeldata")

  post_data <- make_post_data()

  # Formula preprocessor: no recipe is involved, so the error message
  # should use the generic "did your preprocessing remove rows?" hint.
  wflow <- workflow(class ~ ., logistic_reg())
  wflow_fit <- fit(wflow, post_data$data)

  split_data <- tune:::.get_data_subsets(
    wflow,
    post_data$rs$splits[[1]],
    post_data$args
  )

  static <- tune:::make_static(
    wflow,
    param_info = extract_parameter_set_dials(wflow),
    grid = tibble(),
    metrics = metric_set(accuracy),
    eval_time = NULL,
    split_args = post_data$args,
    control = tune::control_grid()
  )
  static <- tune:::update_static(static, split_data)

  # Formula/variable preprocessors don't drop rows on their own, so
  # simulate a row-dropping preprocessor by truncating the forged
  # predictors to five rows.
  local_mocked_bindings(
    forge_from_workflow = function(new_data, workflow) {
      forged <- hardhat::forge(
        new_data,
        workflow$pre$mold$blueprint,
        outcomes = TRUE
      )
      forged$predictors <- forged$predictors[seq_len(5), ]
      forged
    }
  )

  expect_snapshot(
    tune:::process_prediction_data(wflow_fit, static),
    error = TRUE
  )
})
Loading