From 2ffb24f1f6654fe440d4ae740eaa439399e40709 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Mon, 26 Jan 2026 14:17:01 +0100 Subject: [PATCH 1/9] `use_claude_code()` --- .claude/.gitignore | 1 + .claude/CLAUDE.md | 82 ++++++ .claude/settings.json | 32 +++ .../skills/tidy-argument-checking/SKILL.md | 250 ++++++++++++++++++ .../skills/tidy-deprecate-function/SKILL.md | 156 +++++++++++ 5 files changed, 521 insertions(+) create mode 100644 .claude/.gitignore create mode 100644 .claude/CLAUDE.md create mode 100644 .claude/settings.json create mode 100644 .claude/skills/tidy-argument-checking/SKILL.md create mode 100644 .claude/skills/tidy-deprecate-function/SKILL.md diff --git a/.claude/.gitignore b/.claude/.gitignore new file mode 100644 index 00000000..93c0f73f --- /dev/null +++ b/.claude/.gitignore @@ -0,0 +1 @@ +settings.local.json diff --git a/.claude/CLAUDE.md b/.claude/CLAUDE.md new file mode 100644 index 00000000..5e261702 --- /dev/null +++ b/.claude/CLAUDE.md @@ -0,0 +1,82 @@ +## R package development + +### Key commands + +``` +# To run code +Rscript -e "devtools::load_all(); code" + +# To run all tests +Rscript -e "devtools::test()" + +# To run all tests for files starting with {name} +Rscript -e "devtools::test(filter = '^{name}')" + +# To run all tests for R/{name}.R +Rscript -e "devtools::test_active_file('R/{name}.R')" + +# To run a single test "blah" for R/{name}.R +Rscript -e "devtools::test_active_file('R/{name}.R', desc = 'blah')" + +# To redocument the package +Rscript -e "devtools::document()" + +# To check pkgdown documentation +Rscript -e "pkgdown::check_pkgdown()" + +# To check the package with R CMD check +Rscript -e "devtools::check()" + +# To format code +air format . +``` + +### Coding + +* Always run `air format .` after generating code +* Use the base pipe operator (`|>`) not the magrittr pipe (`%>%`) +* Don't use `_$x` or `_$[["x"]]` since this package must work on R 4.1. +* Use `\() ...` for single-line anonymous functions. For all other cases, use `function() {...}` + +### Testing + +- Tests for `R/{name}.R` go in `tests/testthat/test-{name}.R`. +- All new code should have an accompanying test. +- If there are existing tests, place new tests next to similar existing tests. +- Strive to keep your tests minimal with few comments. + +### Documentation + +- Every user-facing function should be exported and have roxygen2 documentation. +- Wrap roxygen comments at 80 characters. +- Internal functions should not have roxygen documentation. +- Whenever you add a new (non-internal) documentation topic, also add the topic to `_pkgdown.yml`. +- Always re-document the package after changing a roxygen2 comment. +- Use `pkgdown::check_pkgdown()` to check that all topics are included in the reference index. + +### `NEWS.md` + +- Every user-facing change should be given a bullet in `NEWS.md`. Do not add bullets for small documentation changes or internal refactorings. +- Each bullet should briefly describe the change to the end user and mention the related issue in parentheses. +- A bullet can consist of multiple sentences but should not contain any new lines (i.e. DO NOT line wrap). +- If the change is related to a function, put the name of the function early in the bullet. +- Order bullets alphabetically by function name. Put all bullets that don't mention function names at the beginning. + +### GitHub + +- If you use `gh` to retrieve information about an issue, always use `--comments` to read all the comments. + +### Writing + +- Use sentence case for headings. +- Use US English. + +### Proofreading + +If the user asks you to proofread a file, act as an expert proofreader and editor with a deep understanding of clear, engaging, and well-structured writing. + +Work paragraph by paragraph, always starting by making a TODO list that includes individual items for each top-level heading. + +Fix spelling, grammar, and other minor problems without asking the user. Label any unclear, confusing, or ambiguous sentences with a FIXME comment. + +Only report what you have changed. diff --git a/.claude/settings.json b/.claude/settings.json new file mode 100644 index 00000000..6e4321da --- /dev/null +++ b/.claude/settings.json @@ -0,0 +1,32 @@ +{ + + "permissions": { + "$schema": "https://json.schemastore.org/claude-code-settings.json", + "defaultMode": "acceptEdits", + "allow": [ + "Bash(air:*)", + "Bash(cat:*)", + "Bash(find:*)", + "Bash(gh issue list:*)", + "Bash(gh issue view:*)", + "Bash(gh pr diff:*)", + "Bash(gh pr view:*)", + "Bash(git checkout:*)", + "Bash(git grep:*)", + "Bash(grep:*)", + "Bash(ls:*)", + "Bash(R:*)", + "Bash(rm:*)", + "Bash(Rscript:*)", + "Bash(sed:*)", + "Skill(*)", + "WebFetch(domain:cran.r-project.org)", + "WebFetch(domain:github.com)", + "WebFetch(domain:raw.githubusercontent.com)" + ], + "deny": [ + "Read(.Renviron)", + "Read(.env)" + ] + } +} diff --git a/.claude/skills/tidy-argument-checking/SKILL.md b/.claude/skills/tidy-argument-checking/SKILL.md new file mode 100644 index 00000000..ed46c19e --- /dev/null +++ b/.claude/skills/tidy-argument-checking/SKILL.md @@ -0,0 +1,250 @@ +--- +name: types-check +description: Validate function inputs in R using a standalone file of check_* functions. Use when writing exported R functions that need input validation, reviewing existing validation code, or when creating new input validation helpers. +--- + +# Input validation in R functions + +This skill describes tidyverse style for validating function inputs. It focuses on rlang's exported type checkers along with the standalone file of `check_*` functions. These functions are carefully designed to produce clear, actionable error messages: + +```r +check_string(123) +#> Error: `123` must be a single string, not the number 123. + +check_number_whole(3.14, min = 1, max = 10) +#> Error: `3.14` must be a whole number, not the number 3.14. +``` + +It assumes that the user has already run `usethis::use_standalone("r-lib/rlang", "types-check")`, and imports rlang in their package. If not, confirm with the user before continuing. + +## Function reference + +### Scalars (single values) + +Use scalar checkers when arguments parameterise the function (configuration flags, names, single counts), rather than represent vectors of user data. They all assert a single value. + +- `check_bool()`: Single TRUE/FALSE (use for flags/options) +- `check_string()`: Single string (allows empty `""` by default) +- `check_name()`: Single non-empty string (for variable names, symbols as strings) +- `check_number_whole()`: Single integer-like numeric value +- `check_number_decimal()`: Single numeric value (allows decimals) + +By default, scalar checkers do _not_ allow `NA` elements (`allow_na = FALSE`). Set `allow_na = TRUE` when missing values are allowed. + +With the number checkers you can use `min` and `max` arguments for range validation, and `allow_infinite` (default `TRUE` for decimals, `FALSE` for whole numbers). + +### Vectors + +- `check_logical()`: Logical vector of any length +- `check_character()`: Character vector of any length +- `check_data_frame()`: A data frame object + +By default, vector checkers allow `NA` elements (`allow_na = TRUE`). Set `allow_na = FALSE` when missing values are not allowed. + +### Optional values: `allow_null` + +Use `allow_null = TRUE` when `NULL` represents a valid "no value" state, similar to `Option` in Rust or `T | null` in TypeScript: + +```r +# NULL means "use default timeout" +check_number_decimal(timeout, allow_null = TRUE) +``` + +The tidyverse style guide recommends using `NULL` defaults instead of `missing()` defaults, so this pattern comes up often in practice. + +## Other helpers + +These functions are exported by rlang. + +- `arg_match()`: Validates enumerated choices. Use when an argument must be one of a known set of strings. + + ```r + # Validates and returns the matched value + my_plot <- function(color = c("red", "green", "blue")) { + color <- rlang::arg_match(color) + # ... + } + + my_plot("redd") + #> Error in `my_plot()`: + #> ! `color` must be one of "red", "green", or "blue", not "redd". + #> ℹ Did you mean "red"? + ``` + + Note that partial matching is an error, unlike `base::match.arg()`. + +- `check_exclusive()` ensures only one of two arguments can be supplied. Supplying both together (i.e. both of them are non-`NULL`) is an error. Use `.require = TRUE` if both can be omitted. + +- `check_required()`: Nice error message if required argument is not supplied. + +## `call` and `arg` arguments + +All check functions have `call` and `arg` arguments, but you should never use these unless you are creating your own `check_` function (see below for more details). + +## When to validate inputs + +**Validate at entry points, not everywhere.** + +Input validation should happen at the boundary between user code and your package's internal implementation: + +- **Exported functions**: Functions users call directly +- **Functions accepting user data**: Even internal functions if they directly consume user input, or external data (e.g. unserialised data) + +Once inputs are validated at these entry points, internal helper functions can trust the data they receive without checking again. + +A good analogy to keep in mind is gradual typing. Think of input validation like TypeScript type guards. Once you've validated data at the boundary, you can treat it as "typed" within your internal functions. Additional runtime checks are not needed. The entry point validates once, and all downstream code benefits. + +Exception: Validate when in doubt. Do validate in internal functions if: +- The cost of invalid data is high (data corruption, security issues) +- The function or context is complex and you want defensive checks + +Example of validating arguments of an exported function: + +```r +# Exported function: VALIDATE +#' @export +create_report <- function(title, n_rows) { + check_string(title) + check_number_whole(n_rows, min = 1) + + # Now call helpers with validated data + data <- generate_data(n_rows) + format_report(title, data) +} +``` + +Once data is validated at the entry point, internal helpers can skip validation: + +```r +# Internal helper: NO VALIDATION NEEDED +generate_data <- function(n_rows) { + # n_rows is already validated, just use it + data.frame( + id = seq_len(n_rows), + value = rnorm(n_rows) + ) +} + +# Internal helper: NO VALIDATION NEEDED +format_report <- function(title, data) { + # title and data are already validated, just use them + list( + title = title, + summary = summary(data), + rows = nrow(data) + ) +} +``` + +Note how the `data` generated by `generate_data()` doesn't need validation either. Internal code creating data in a trusted way (e.g. because it's simple or because it's covered by unit tests) doesn't require internal checks. + +## Early input checking + +Always validate inputs at the start of user-facing functions, before doing any work: + +```r +my_function <- function(x, name, env = caller_env()) { + check_logical(x) + check_name(name) + check_environment(env) + + # ... function body +} +``` + +Benefits: + +- This self-documents the types of the arguments +- Eager evaluation also reduces the risk of confusing lazy evaluation effects + +## Custom validation functions + +Most packages will need one or more unique checker functions. Sometimes it's sufficient to wrap existing check functions with custom arguments. In this case you just need to carefully pass through the `arg` and `call` arguments. In other cases, you want a completely new check in which case you can call `stop_input_type` with your own arguments. + +### Wrapping existing `check_` functions + +When creating a wrapper or helper function that calls `check_*` functions on behalf of another function, you **must** propagate the caller context. Otherwise, errors will point to your wrapper function instead of the actual entry point. + +Without proper propagation, error messages show the wrong function and argument names: + +```r +# WRONG: errors will point to check_positive's definition +check_positive <- function(x) { + check_number_whole(x, min = 1) +} + +my_function <- function(count) { + check_positive(count) +} + +my_function(-5) +#> Error in `check_positive()`: # Wrong! Should say `my_function()` +#> ! `x` must be a whole number larger than or equal to 1. # Wrong! Should say `count` +``` + +With proper propagation, errors correctly identify the entry point and argument: + +```r +# CORRECT: propagates context from the entry point +check_positive <- function(x, arg = caller_arg(x), call = caller_env()) { + check_number_whole(x, min = 1, arg = arg, call = call) +} + +my_function <- function(count) { + check_positive(count) +} + +my_function(-5) +#> Error in `my_function()`: # Correct! +#> ! `count` must be a whole number larger than or equal to 1. # Correct! +``` + +Note how `arg` and `call` are part of the function signature. That allows them to be wrapped again by another checking function that can pass down its own context. + +### Creating a new `check_` function + +When constructing your own `check_` function you can call `stop_input_type()` to take advantage of the existing infrastructure for generating error messages. +For example, imagine we wanted to create a function that checked that the input was a single date: + +```R +check_date <- function(x, ..., allow_null = FALSE, arg = caller_arg(x), call = caller_env()) { + if (!missing(x) && is.Date(x) && length(x) == 1) { + return(invisible()) + } + + stop_input_type( + x, + "a single Date", + ..., + allow_null = allow_null, + arg = arg, + call = call + ) +} +``` + +Note you must always check first that the input is not missing, as `stop_input_type()` handles this case specially. + +Sometimes you want to check if something is a compound type: + +```R +check_string_or_bool <- function(x, ..., arg = caller_arg(x), call = caller_env()) { + if (!missing(x)) { + if (is_string(x) || isTRUE(x) || isFALSE(x)) { + return(invisible()) + } + } + + stop_input_type( + x, + c("a string", "TRUE", "FALSE"), + ..., + arg = arg, + call = call + ) +} +``` + +Note that the second argument to `stop_input_type()` can take a vector, and it will automatically places commas and "and" in the appropriate locations. + +Generally, you should place this `check_` function close to the function that is usually used to construct the object being checked (e.g. close to the S3/S4/S7 constructor.) diff --git a/.claude/skills/tidy-deprecate-function/SKILL.md b/.claude/skills/tidy-deprecate-function/SKILL.md new file mode 100644 index 00000000..b2dfeb30 --- /dev/null +++ b/.claude/skills/tidy-deprecate-function/SKILL.md @@ -0,0 +1,156 @@ +--- +name: tidy-deprecate-function +description: Guide for deprecating R functions/arguments. Use when a user asks to deprecate a function or parameter, including adding lifecycle warnings, updating documentation, adding NEWS entries, and updating tests. +--- + +# Deprecate functions and function arguments + +Use this skill when deprecating functions or function parameters in this package. + +## Overview + +This skill guides you through the complete process of deprecating a function or parameter, ensuring all necessary changes are made consistently: + +1. Add deprecation warning using `lifecycle::deprecate_warn()`. +2. Silence deprecation warnings in existing tests. +3. Add lifecycle badge to documentation. +4. Add bullet point to NEWS.md. +5. Create test for deprecation warning. + +## Workflow + +### Step 1: Determine deprecation version + +Read the current version from DESCRIPTION and calculate the deprecation version: + +- Current version format: `MAJOR.MINOR.PATCH.9000` (development). +- Deprecation version: Next minor release `MAJOR.(MINOR+1).0`. +- Example: If current version is `2.5.1.9000`, deprecation version is `2.6.0`. + +### Step 2: Add `lifecycle::deprecate_warn()` call + +Add the deprecation warning to the function: + +```r +# For a deprecated function: +function_name <- function(...) { + lifecycle::deprecate_warn("X.Y.0", "function_name()", "replacement_function()") + # rest of function +} + +# For a deprecated parameter: +function_name <- function(param1, deprecated_param = deprecated()) { + if (lifecycle::is_present(deprecated_param)) { + lifecycle::deprecate_warn("X.Y.0", "function_name(deprecated_param)") + } + # rest of function +} +``` + +Key points: + +- First argument is the deprecation version string (e.g., "2.6.0"). +- Second argument describes what is deprecated (e.g., "function_name(param)"). +- Optional third argument suggests replacement. +- Use `lifecycle::is_present()` to check if a deprecated parameter was supplied. + +### Step 3: Update tests + +Find all existing tests that use the deprecated function or parameter and silence lifecycle warnings. Add at the beginning of test blocks that use the deprecated feature: + +```r +test_that("existing test with deprecated feature", { + withr::local_options(lifecycle_verbosity = "quiet") + + # existing test code +}) +``` + +Then add a new test to verify the deprecation message in the appropriate test file (usually `tests/testthat/test-{name}.R`): + +```r +test_that("function_name(deprecated_param) is deprecated", { + expect_snapshot(. <- function_name(deprecated_param = value)) +}) +``` + +You'll need to supply any additional arguments to create a valid call. + +Then run the tests and verify they pass. + +### Step 4: Update documentation + +For function deprecation, add to the description section: + +```r +#' @description +#' `r lifecycle::badge("deprecated")` +#' +#' This function is deprecated. Please use [replacement_function()] instead. +``` + +If the documentation does not already contain `@description`, you will need to add it. + +For argument deprecation, add to the appropriate `@param` tag: + +```r +#' @param deprecated_param `r lifecycle::badge("deprecated")` +``` + +When deprecating a function or parameter in favor of a replacement, add old/new examples to the `@examples` section to help users migrate. These should relace all existing examples. + +```r +#' @examples +#' # Old: +#' old_function(arg1, arg2) +#' # New: +#' replacement_function(arg1, arg2) +#' +#' # Old: +#' x <- "value" +#' old_function("prefix", x, "suffix") +#' # New: +#' replacement_function("prefix {x} suffix") +``` + +Key points: + +- Use "# Old:" and "# New:" comments to clearly show the transition. +- Include 2-3 practical examples covering common use cases. +- Make examples runnable and self-contained. +- Show how the new syntax differs from the old. + +Then re-document the package. + +### Step 5: Add NEWS entry + +Add a bullet point to the top of the "# packagename (development version)" section in NEWS.md: + +```markdown +# packagename (development version) + +* `function_name(parameter)` is deprecated and will be removed in a future + version. +* `function_name()` is deprecated. Use `replacement_function()` instead. +``` + +Place the entry: + +- In the lifecycle subsection if it exists, otherwise at the top level under development version. +- Include the replacement if known. +- Keep entries concise and actionable. + +## Implementation checklist + +When deprecating a function or parameter, ensure you: + +- [ ] Read DESCRIPTION to determine deprecation version. +- [ ] Add `lifecycle::deprecate_warn()` call in the function. +- [ ] Add `withr::local_options(lifecycle_verbosity = "quiet")` to existing tests. +- [ ] Create new test for deprecation warning using `expect_snapshot()`. +- [ ] Run tests to verify everything works. +- [ ] Add lifecycle badge to roxygen documentation. +- [ ] Add migration examples to `@examples` section (for function deprecation). +- [ ] Run `devtools::document()` to update documentation. +- [ ] Add bullet point to NEWS.md. +- [ ] Run `air format .` to format code. From c945d3c694e0ef2cad2accede4965b8f037c988e Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Mon, 26 Jan 2026 14:18:53 +0100 Subject: [PATCH 2/9] consolidate claude.md --- .claude/CLAUDE.md | 65 ++++++++++++++++++++++++++++++++++ CLAUDE.md | 90 ----------------------------------------------- 2 files changed, 65 insertions(+), 90 deletions(-) delete mode 100644 CLAUDE.md diff --git a/.claude/CLAUDE.md b/.claude/CLAUDE.md index 5e261702..debad2d0 100644 --- a/.claude/CLAUDE.md +++ b/.claude/CLAUDE.md @@ -1,3 +1,11 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Overview + +dials is an R package that provides infrastructure for creating and managing tuning parameter values in the tidymodels ecosystem. It defines parameter objects, sets of parameters, and methods for generating parameter grids for model tuning. + ## R package development ### Key commands @@ -80,3 +88,60 @@ Work paragraph by paragraph, always starting by making a TODO list that includes Fix spelling, grammar, and other minor problems without asking the user. Label any unclear, confusing, or ambiguous sentences with a FIXME comment. Only report what you have changed. + +## Architecture + +### Core Parameter System + +The package is built around two main parameter types: + +1. **`quant_param`**: Quantitative parameters (continuous or integer) + - Created via `new_quant_param()` in `R/constructors.R` + - Has `range` (lower/upper bounds), `inclusive`, optional `trans` (transformation), and `finalize` function + - Examples: `penalty()`, `mtry()`, `learn_rate()` + +2. **`qual_param`**: Qualitative parameters (categorical) + - Created via `new_qual_param()` in `R/constructors.R` + - Has discrete `values` (character or logical) + - Examples: `activation()`, `weight_func()` + +### Parameter Organization + +- **Individual parameters**: parameter definition files (`R/param_*.R`), each defining specific tuning parameters used across tidymodels +- **Parameter sets**: The `parameters` class (defined in `R/parameters.R`) groups multiple parameters into a data frame-like structure + +### Grid Generation + +Three main grid types (in `R/grids.R` and `R/space_filling.R`): + +1. **Regular grids** (`grid_regular()`): Factorial designs with evenly-spaced values +2. **Random grids** (`grid_random()`): Random sampling from parameter ranges +3. **Space-filling grids** (`grid_space_filling()`): Experimental designs (Latin hypercube, max entropy, etc.) that efficiently cover the parameter space + +All grid functions: +- Accept parameter objects or parameter sets +- Return tibbles with one column per parameter + +### Finalization System + +Many parameters have `unknown()` ranges that depend on the dataset (e.g., `mtry()` depends on the number of predictors). The finalization system (`R/finalize.R`) resolves these: + +- `finalize()`: Generic function that calls the parameter's embedded `finalize` function +- `get_*()`: Various functions that get and set parameter ranges based on data characteristics + +### Infrastructure Files + +Files prefixed with `aaa_` load first and define foundational classes: +- `R/aaa_ranges.R`: Handling and validation of parameter ranges +- `R/aaa_unknown.R`: The `unknown()` placeholder for unspecified parameter bounds +- `R/aaa_values.R`: Validation, generation, and transformation of parameter values + +Files prefixed with `compat-` provide compatibility with dplyr and vctrs for parameter objects. + +## Integration with tidymodels + +dials is infrastructure-level; it defines parameters but doesn't perform tuning. The tune package uses dials for actual hyperparameter tuning. Parameter objects integrate with: +- **parsnip**: Model specifications reference dials parameters +- **recipes**: Preprocessing steps use dials parameters +- **workflows**: Workflows combine models and preprocessing that utilize dials parameters +- **tune**: Grid search and optimization consume parameter grids diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index d61ca5f6..00000000 --- a/CLAUDE.md +++ /dev/null @@ -1,90 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Overview - -dials is an R package that provides infrastructure for creating and managing tuning parameter values in the tidymodels ecosystem. It defines parameter objects, sets of parameters, and methods for generating parameter grids for model tuning. - -## Key development commands - -General advice: -* When running R from the console, always run it with `--quiet --vanilla` -* Always run `air format .` after generating code - -### Testing - -- Use `devtools::test()` to run all tests -- Use `devtools::test_file("tests/testthat/test-filename.R")` to run tests in a specific file -- DO NOT USE `devtools::test_active_file()` -- All testing functions automatically load code; you don't need to. - -- All new code should have an accompanying test. -- Tests for `R/{name}.R` go in `tests/testthat/test-{name}.R`. -- If there are existing tests, place new tests next to similar existing tests. - -### Documentation - -- Run `devtools::document()` after changing any roxygen2 docs. -- Every user facing function should be exported and have roxygen2 documentation. -- Whenever you add a new documentation file, make sure to also add the topic name to `_pkgdown.yml`. -- Run `pkgdown::check_pkgdown()` to check that all topics are included in the reference index. -- Use sentence case for all headings - - -## Architecture - -### Core Parameter System - -The package is built around two main parameter types: - -1. **`quant_param`**: Quantitative parameters (continuous or integer) - - Created via `new_quant_param()` in `R/constructors.R` - - Has `range` (lower/upper bounds), `inclusive`, optional `trans` (transformation), and `finalize` function - - Examples: `penalty()`, `mtry()`, `learn_rate()` - -2. **`qual_param`**: Qualitative parameters (categorical) - - Created via `new_qual_param()` in `R/constructors.R` - - Has discrete `values` (character or logical) - - Examples: `activation()`, `weight_func()` - -### Parameter Organization - -- **Individual parameters**: parameter definition files (`R/param_*.R`), each defining specific tuning parameters used across tidymodels -- **Parameter sets**: The `parameters` class (defined in `R/parameters.R`) groups multiple parameters into a data frame-like structure - -### Grid Generation - -Three main grid types (in `R/grids.R` and `R/space_filling.R`): - -1. **Regular grids** (`grid_regular()`): Factorial designs with evenly-spaced values -2. **Random grids** (`grid_random()`): Random sampling from parameter ranges -3. **Space-filling grids** (`grid_space_filling()`): Experimental designs (Latin hypercube, max entropy, etc.) that efficiently cover the parameter space - -All grid functions: -- Accept parameter objects or parameter sets -- Return tibbles with one column per parameter - -### Finalization System - -Many parameters have `unknown()` ranges that depend on the dataset (e.g., `mtry()` depends on the number of predictors). The finalization system (`R/finalize.R`) resolves these: - -- `finalize()`: Generic function that calls the parameter's embedded `finalize` function -- `get_*()`: Various functions that get and set parameter ranges based on data characteristics - -### Infrastructure Files - -Files prefixed with `aaa_` load first and define foundational classes: -- `R/aaa_ranges.R`: Handling and validation of parameter ranges -- `R/aaa_unknown.R`: The `unknown()` placeholder for unspecified parameter bounds -- `R/aaa_values.R`: Validation, generation, and transformation of parameter values - -Files prefixed with `compat-` provide compatibility with dplyr and vctrs for parameter objects. - -## Integration with tidymodels - -dials is infrastructure-level; it defines parameters but doesn't perform tuning. The tune package uses dials for actual hyperparameter tuning. Parameter objects integrate with: -- **parsnip**: Model specifications reference dials parameters -- **recipes**: Preprocessing steps use dials parameters -- **workflows**: Workflows combine models and preprocessing that utilize dials parameters -- **tune**: Grid search and optimization consume parameter grids From f28926c6d9fe511db6d3f698722d4d7f407ad28e Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 28 Jan 2026 10:37:35 +0000 Subject: [PATCH 3/9] Additional helper functions `check_inherits()` and `check_levels()` --- R/misc.R | 57 +++++++++++++++++++++++++++++++++++ tests/testthat/_snaps/misc.md | 56 ++++++++++++++++++++++++++++++++++ tests/testthat/test-misc.R | 22 ++++++++++++++ 3 files changed, 135 insertions(+) diff --git a/R/misc.R b/R/misc.R index 4d7458d2..51510cb7 100644 --- a/R/misc.R +++ b/R/misc.R @@ -218,3 +218,60 @@ check_param <- function( call = call ) } + +check_inherits <- function( + x, + class, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env() +) { + check_dots_empty() + if (!missing(x)) { + if (inherits(x, class)) { + return(invisible(NULL)) + } + if (allow_null && is.null(x)) { + return(invisible(NULL)) + } + } + + stop_input_type( + x, + paste0("a <", class, "> object"), + ..., + allow_null = allow_null, + arg = arg, + call = call + ) +} + +check_levels <- function( + levels, + ..., + allow_null = FALSE, + arg = caller_arg(levels), + call = caller_env() +) { + check_dots_empty() + if (!missing(levels)) { + if ( + is.numeric(levels) && all(levels >= 1) && all(levels == floor(levels)) + ) { + return(invisible(NULL)) + } + if (allow_null && is.null(levels)) { + return(invisible(NULL)) + } + } + + stop_input_type( + levels, + "a positive integer or a vector of positive integers", + ..., + allow_null = allow_null, + arg = arg, + call = call + ) +} diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index 51350764..f3e9bc05 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -70,6 +70,62 @@ Error: ! `1:2` must be a logical vector, not an integer vector. +# check_inherits() + + Code + check_inherits("not a data frame", "data.frame") + Condition + Error: + ! `"not a data frame"` must be a object, not the string "not a data frame". + +--- + + Code + check_inherits(NULL, "data.frame") + Condition + Error: + ! `NULL` must be a object, not `NULL`. + +# check_levels() + + Code + check_levels(0) + Condition + Error: + ! `0` must be a positive integer or a vector of positive integers, not the number 0. + +--- + + Code + check_levels(-1) + Condition + Error: + ! `-1` must be a positive integer or a vector of positive integers, not the number -1. + +--- + + Code + check_levels(1.5) + Condition + Error: + ! `1.5` must be a positive integer or a vector of positive integers, not the number 1.5. + +--- + + Code + check_levels("a") + Condition + Error: + ! `"a"` must be a positive integer or a vector of positive integers, not the string "a". + +--- + + Code + check_levels(NULL) + Condition + Error: + ! `NULL` must be a positive integer or a vector of positive integers, not `NULL`. + # vctrs-helpers-parameters Code diff --git a/tests/testthat/test-misc.R b/tests/testthat/test-misc.R index 0d56f1ca..48f48840 100644 --- a/tests/testthat/test-misc.R +++ b/tests/testthat/test-misc.R @@ -56,6 +56,28 @@ test_that("check_inclusive()", { expect_snapshot(error = TRUE, check_inclusive(1:2)) }) +test_that("check_inherits()", { + expect_no_error(check_inherits(mtcars, "data.frame")) + expect_no_error(check_inherits(NULL, "data.frame", allow_null = TRUE)) + + expect_snapshot( + error = TRUE, + check_inherits("not a data frame", "data.frame") + ) + expect_snapshot(error = TRUE, check_inherits(NULL, "data.frame")) +}) + +test_that("check_levels()", { + expect_no_error(check_levels(1)) + expect_no_error(check_levels(1:5)) + expect_no_error(check_levels(NULL, allow_null = TRUE)) + + expect_snapshot(error = TRUE, check_levels(0)) + expect_snapshot(error = TRUE, check_levels(-1)) + expect_snapshot(error = TRUE, check_levels(1.5)) + expect_snapshot(error = TRUE, check_levels("a")) + expect_snapshot(error = TRUE, check_levels(NULL)) +}) test_that("vctrs-helpers-parameters", { expect_false(dials:::is_parameters(2)) From d39b8cc9008a1bbb08140205ab8186f722ea9095 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 28 Jan 2026 10:57:56 +0000 Subject: [PATCH 4/9] Improve input checks for regular and random grids --- R/grids.R | 12 +++++++- tests/testthat/_snaps/grids.md | 56 ++++++++++++++++++++++++++++++++++ tests/testthat/test-grids.R | 16 ++++++++++ 3 files changed, 83 insertions(+), 1 deletion(-) diff --git a/R/grids.R b/R/grids.R index a755cc02..30023947 100644 --- a/R/grids.R +++ b/R/grids.R @@ -134,6 +134,8 @@ make_regular_grid <- function( filter = NULL, call = caller_env() ) { + check_levels(levels, call = call) + check_bool(original, call = call) validate_params(..., call = call) filter_quo <- enquo(filter) param_quos <- quos(...) @@ -156,10 +158,16 @@ make_regular_grid <- function( levels <- levels[names(params)] } else if (any(rlang::has_name(levels, names(params)))) { cli::cli_abort( - "Elements of {.arg levels} should either be all named or unnamed, + "Elements of {.arg levels} should either be all named or unnamed, not mixed.", call = call ) + } else if (!is.null(names(levels))) { + cli::cli_abort( + "The names of {.arg levels} must match the parameter names: + {.val {names(params)}}.", + call = call + ) } param_seq <- map2(params, as.list(levels), value_seq, original = original) } @@ -259,6 +267,8 @@ make_random_grid <- function( filter = NULL, call = caller_env() ) { + check_number_whole(size, min = 1, call = call) + check_bool(original, call = call) validate_params(..., call = call) filter_quo <- enquo(filter) param_quos <- quos(...) diff --git a/tests/testthat/_snaps/grids.md b/tests/testthat/_snaps/grids.md index aa6e3a5d..4264d411 100644 --- a/tests/testthat/_snaps/grids.md +++ b/tests/testthat/_snaps/grids.md @@ -23,6 +23,14 @@ Error in `grid_regular()`: ! Elements of `levels` should either be all named or unnamed, not mixed. +--- + + Code + grid_regular(mixture(), trees(), levels = c(wrong = 2, names = 4)) + Condition + Error in `grid_regular()`: + ! The names of `levels` must match the parameter names: "mixture" and "trees". + # wrong argument name Code @@ -59,6 +67,54 @@ ! `size` is not an argument to `grid_regular()`. i Did you mean `levels`? +# grid_random validates inputs + + Code + grid_random(penalty(), size = "five") + Condition + Error in `grid_random()`: + ! `size` must be a whole number, not the string "five". + +--- + + Code + grid_random(penalty(), size = -1) + Condition + Error in `grid_random()`: + ! `size` must be a whole number larger than or equal to 1, not the number -1. + +--- + + Code + grid_random(penalty(), original = "yes") + Condition + Error in `grid_random()`: + ! `original` must be `TRUE` or `FALSE`, not the string "yes". + +# grid_regular validates inputs + + Code + grid_regular(penalty(), levels = "three") + Condition + Error in `grid_regular()`: + ! `levels` must be a positive integer or a vector of positive integers, not the string "three". + +--- + + Code + grid_regular(penalty(), levels = -1) + Condition + Error in `grid_regular()`: + ! `levels` must be a positive integer or a vector of positive integers, not the number -1. + +--- + + Code + grid_regular(penalty(), original = "yes") + Condition + Error in `grid_regular()`: + ! `original` must be `TRUE` or `FALSE`, not the string "yes". + # new param grid from conventional data frame Code diff --git a/tests/testthat/test-grids.R b/tests/testthat/test-grids.R index 7095d1e4..8c8dc839 100644 --- a/tests/testthat/test-grids.R +++ b/tests/testthat/test-grids.R @@ -46,6 +46,10 @@ test_that("regular grid", { error = TRUE, grid_regular(mixture(), trees(), levels = c(2, trees = 4)) ) + expect_snapshot( + error = TRUE, + grid_regular(mixture(), trees(), levels = c(wrong = 2, names = 4)) + ) }) test_that("random grid", { @@ -98,6 +102,18 @@ test_that("filter arg yields same results", { }) +test_that("grid_random validates inputs", { + expect_snapshot(error = TRUE, grid_random(penalty(), size = "five")) + expect_snapshot(error = TRUE, grid_random(penalty(), size = -1)) + expect_snapshot(error = TRUE, grid_random(penalty(), original = "yes")) +}) + +test_that("grid_regular validates inputs", { + expect_snapshot(error = TRUE, grid_regular(penalty(), levels = "three")) + expect_snapshot(error = TRUE, grid_regular(penalty(), levels = -1)) + expect_snapshot(error = TRUE, grid_regular(penalty(), original = "yes")) +}) + test_that("new param grid from conventional data frame", { x <- data.frame(num_comp = 1:3) From 9c1e58a7c8cdda25d0502a75a78a86c51aedba4a Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 28 Jan 2026 11:51:52 +0000 Subject: [PATCH 5/9] Improve input checking for space-filling grids --- R/space_filling.R | 18 ++++++++++++ tests/testthat/_snaps/space_filling.md | 40 ++++++++++++++++++++++++++ tests/testthat/test-space_filling.R | 11 +++++++ 3 files changed, 69 insertions(+) diff --git a/R/space_filling.R b/R/space_filling.R index f1d5ed81..c7341d7c 100644 --- a/R/space_filling.R +++ b/R/space_filling.R @@ -206,6 +206,14 @@ make_sfd <- function( original = TRUE, call = caller_env() ) { + check_number_whole(size, min = 1, call = call) + check_number_decimal( + variogram_range, + min = 0, + call = call + ) + check_number_whole(iter, min = 1, call = call) + check_bool(original, call = call) type <- rlang::arg_match(type, sfd_types) validate_params(..., call = call) param_quos <- quos(...) @@ -397,6 +405,14 @@ make_max_entropy_grid <- function( iter = 1000, call = caller_env() ) { + check_number_whole(size, min = 1, call = call) + check_number_decimal( + variogram_range, + min = 0, + call = call + ) + check_number_whole(iter, min = 1, call = call) + check_bool(original, call = call) validate_params(..., call = call) param_quos <- quos(...) params <- map(param_quos, eval_tidy) @@ -491,6 +507,8 @@ make_latin_hypercube_grid <- function( original = TRUE, call = caller_env() ) { + check_number_whole(size, min = 1, call = call) + check_bool(original, call = call) validate_params(..., call = call) param_quos <- quos(...) params <- map(param_quos, eval_tidy) diff --git a/tests/testthat/_snaps/space_filling.md b/tests/testthat/_snaps/space_filling.md index 5e85e270..eec6b515 100644 --- a/tests/testthat/_snaps/space_filling.md +++ b/tests/testthat/_snaps/space_filling.md @@ -73,3 +73,43 @@ ! `levels` is not an argument to `grid_space_filling()`. i Did you mean `size`? +# grid_space_filling validates inputs + + Code + grid_space_filling(penalty(), size = "five") + Condition + Error in `grid_space_filling()`: + ! `size` must be a whole number, not the string "five". + +--- + + Code + grid_space_filling(penalty(), size = -1) + Condition + Error in `grid_space_filling()`: + ! `size` must be a whole number larger than or equal to 1, not the number -1. + +--- + + Code + grid_space_filling(penalty(), variogram_range = -1) + Condition + Error in `grid_space_filling()`: + ! `variogram_range` must be a number larger than or equal to 0, not the number -1. + +--- + + Code + grid_space_filling(penalty(), iter = "many") + Condition + Error in `grid_space_filling()`: + ! `iter` must be a whole number, not the string "many". + +--- + + Code + grid_space_filling(penalty(), original = "yes") + Condition + Error in `grid_space_filling()`: + ! `original` must be `TRUE` or `FALSE`, not the string "yes". + diff --git a/tests/testthat/test-space_filling.R b/tests/testthat/test-space_filling.R index d0ec9ed6..7323d17b 100644 --- a/tests/testthat/test-space_filling.R +++ b/tests/testthat/test-space_filling.R @@ -331,6 +331,17 @@ test_that("1-point grid", { expect_equal(nrow(grid), 1L) }) +test_that("grid_space_filling validates inputs", { + expect_snapshot(error = TRUE, grid_space_filling(penalty(), size = "five")) + expect_snapshot(error = TRUE, grid_space_filling(penalty(), size = -1)) + expect_snapshot( + error = TRUE, + grid_space_filling(penalty(), variogram_range = -1) + ) + expect_snapshot(error = TRUE, grid_space_filling(penalty(), iter = "many")) + expect_snapshot(error = TRUE, grid_space_filling(penalty(), original = "yes")) +}) + test_that("pre-made designs respect the 'original argument", { # See issue #409 From 2ea0758185bb25fde46e2536a49eace1139706a8 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 28 Jan 2026 13:16:23 +0000 Subject: [PATCH 6/9] Improve input checks for finalizing functions --- R/finalize.R | 16 +++++++-- R/misc.R | 22 ++++++++++++ tests/testthat/_snaps/finalize.md | 60 +++++++++++++++++++++++++++++-- tests/testthat/_snaps/misc.md | 48 +++++++++++++++++++++++++ tests/testthat/test-finalize.R | 13 +++++++ tests/testthat/test-misc.R | 12 +++++++ 6 files changed, 167 insertions(+), 4 deletions(-) diff --git a/R/finalize.R b/R/finalize.R index 19ec61b0..7628df5c 100644 --- a/R/finalize.R +++ b/R/finalize.R @@ -94,12 +94,14 @@ finalize <- function(object, ...) { #' @export #' @rdname finalize finalize.list <- function(object, x, force = TRUE, ...) { + check_bool(force) map(object, finalize, x, force, ...) } #' @export #' @rdname finalize finalize.param <- function(object, x, force = TRUE, ...) { + check_bool(force) if (is.null(object$finalize)) { return(object) } @@ -114,7 +116,7 @@ safe_finalize <- function(object, x, force = TRUE, ...) { if (all(is.na(object))) { res <- NA } else { - res <- finalize(object, x, force = force, ...) + res <- finalize(object, x = x, force = force, ...) } res } @@ -122,7 +124,7 @@ safe_finalize <- function(object, x, force = TRUE, ...) { #' @export #' @rdname finalize finalize.parameters <- function(object, x, force = TRUE, ...) { - object$object <- map(object$object, safe_finalize, x, force, ...) + object$object <- map(object$object, safe_finalize, x = x, force = force, ...) object } @@ -153,6 +155,7 @@ finalize.default <- function(object, x, force = TRUE, ...) { #' @rdname finalize get_p <- function(object, x, log_vals = FALSE, ...) { check_param(object) + check_bool(log_vals) rngs <- range_get(object, original = FALSE) if (!is_unknown(rngs$upper)) { @@ -182,6 +185,7 @@ get_p <- function(object, x, log_vals = FALSE, ...) { #' @export #' @rdname finalize get_log_p <- function(object, x, ...) { + check_param(object) get_p(object, x, log_vals = TRUE, ...) } @@ -189,6 +193,8 @@ get_log_p <- function(object, x, ...) { #' @rdname finalize get_n_frac <- function(object, x, log_vals = FALSE, frac = 1 / 3, ...) { check_param(object) + check_bool(log_vals) + check_number_decimal(frac, min = 0, max = 1) rngs <- range_get(object, original = FALSE) if (!is_unknown(rngs$upper)) { @@ -224,6 +230,10 @@ get_n_frac_range <- function( frac = c(1 / 10, 5 / 10), ... ) { + check_param(object) + check_bool(log_vals) + check_frac_range(frac) + rngs <- range_get(object, original = FALSE) if (!is_unknown(rngs$upper)) { return(object) @@ -254,12 +264,14 @@ get_n_frac_range <- function( #' @export #' @rdname finalize get_n <- function(object, x, log_vals = FALSE, ...) { + check_param(object) get_n_frac(object, x, log_vals, frac = 1, ...) } #' @export #' @rdname finalize get_rbf_range <- function(object, x, seed = sample.int(10^5, 1), ...) { + check_param(object) rlang::check_installed("kernlab") suppressPackageStartupMessages(requireNamespace("kernlab", quietly = TRUE)) x_mat <- as.matrix(x) diff --git a/R/misc.R b/R/misc.R index 51510cb7..d40204cb 100644 --- a/R/misc.R +++ b/R/misc.R @@ -275,3 +275,25 @@ check_levels <- function( call = call ) } + +check_frac_range <- function(x, ..., arg = caller_arg(x), call = caller_env()) { + check_dots_empty() + if ( + !missing(x) && + is.numeric(x) && + length(x) == 2 && + !anyNA(x) && + all(x >= 0) && + all(x <= 1) + ) { + return(invisible(NULL)) + } + + stop_input_type( + x, + "a numeric vector of length 2 with values between 0 and 1", + ..., + arg = arg, + call = call + ) +} diff --git a/tests/testthat/_snaps/finalize.md b/tests/testthat/_snaps/finalize.md index 11186691..63520fb4 100644 --- a/tests/testthat/_snaps/finalize.md +++ b/tests/testthat/_snaps/finalize.md @@ -1,3 +1,59 @@ +# finalize validates inputs + + Code + finalize(penalty(), mtcars, force = "yes") + Condition + Error in `finalize()`: + ! `force` must be `TRUE` or `FALSE`, not the string "yes". + +--- + + Code + get_p(penalty(), mtcars, log_vals = "yes") + Condition + Error in `get_p()`: + ! `log_vals` must be `TRUE` or `FALSE`, not the string "yes". + +--- + + Code + get_n_frac(mtry(), mtcars, frac = "half") + Condition + Error in `get_n_frac()`: + ! `frac` must be a number, not the string "half". + +--- + + Code + get_n_frac(mtry(), mtcars, frac = 1.5) + Condition + Error in `get_n_frac()`: + ! `frac` must be a number between 0 and 1, not the number 1.5. + +--- + + Code + get_n_frac_range(mtry(), mtcars, frac = 0.5) + Condition + Error in `get_n_frac_range()`: + ! `frac` must be a numeric vector of length 2 with values between 0 and 1, not the number 0.5. + +--- + + Code + get_n_frac_range(mtry(), mtcars, frac = c(0.1, 1.5)) + Condition + Error in `get_n_frac_range()`: + ! `frac` must be a numeric vector of length 2 with values between 0 and 1, not a double vector. + +--- + + Code + get_rbf_range("not a param", mtcars) + Condition + Error in `get_rbf_range()`: + ! `object` must be a single parameter object, not the string "not a param". + # estimate columns Code @@ -27,7 +83,7 @@ Code get_n(1:10) Condition - Error in `get_n_frac()`: + Error in `get_n()`: ! `object` must be a single parameter object, not an integer vector. --- @@ -35,7 +91,7 @@ Code get_n(1:10, 1:10) Condition - Error in `get_n_frac()`: + Error in `get_n()`: ! `object` must be a single parameter object, not an integer vector. --- diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index f3e9bc05..60ff5a19 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -126,6 +126,54 @@ Error: ! `NULL` must be a positive integer or a vector of positive integers, not `NULL`. +# check_frac_range() + + Code + check_frac_range("not numeric") + Condition + Error: + ! `"not numeric"` must be a numeric vector of length 2 with values between 0 and 1, not the string "not numeric". + +--- + + Code + check_frac_range(0.5) + Condition + Error: + ! `0.5` must be a numeric vector of length 2 with values between 0 and 1, not the number 0.5. + +--- + + Code + check_frac_range(c(0.1, 0.5, 0.9)) + Condition + Error: + ! `c(0.1, 0.5, 0.9)` must be a numeric vector of length 2 with values between 0 and 1, not a double vector. + +--- + + Code + check_frac_range(c(-0.1, 0.5)) + Condition + Error: + ! `c(-0.1, 0.5)` must be a numeric vector of length 2 with values between 0 and 1, not a double vector. + +--- + + Code + check_frac_range(c(0.1, 1.5)) + Condition + Error: + ! `c(0.1, 1.5)` must be a numeric vector of length 2 with values between 0 and 1, not a double vector. + +--- + + Code + check_frac_range(c(0.1, NA)) + Condition + Error: + ! `c(0.1, NA)` must be a numeric vector of length 2 with values between 0 and 1, not a double vector. + # vctrs-helpers-parameters Code diff --git a/tests/testthat/test-finalize.R b/tests/testthat/test-finalize.R index 790df1c9..85231edb 100644 --- a/tests/testthat/test-finalize.R +++ b/tests/testthat/test-finalize.R @@ -1,3 +1,16 @@ +test_that("finalize validates inputs", { + expect_snapshot(error = TRUE, finalize(penalty(), mtcars, force = "yes")) + expect_snapshot(error = TRUE, get_p(penalty(), mtcars, log_vals = "yes")) + expect_snapshot(error = TRUE, get_n_frac(mtry(), mtcars, frac = "half")) + expect_snapshot(error = TRUE, get_n_frac(mtry(), mtcars, frac = 1.5)) + expect_snapshot(error = TRUE, get_n_frac_range(mtry(), mtcars, frac = 0.5)) + expect_snapshot( + error = TRUE, + get_n_frac_range(mtry(), mtcars, frac = c(0.1, 1.5)) + ) + expect_snapshot(error = TRUE, get_rbf_range("not a param", mtcars)) +}) + test_that("estimate columns", { expect_snapshot(error = TRUE, get_p(1:10)) expect_snapshot(error = TRUE, get_p(1:10, 1:10)) diff --git a/tests/testthat/test-misc.R b/tests/testthat/test-misc.R index 48f48840..66013c95 100644 --- a/tests/testthat/test-misc.R +++ b/tests/testthat/test-misc.R @@ -79,6 +79,18 @@ test_that("check_levels()", { expect_snapshot(error = TRUE, check_levels(NULL)) }) +test_that("check_frac_range()", { + expect_no_error(check_frac_range(c(0.1, 0.5))) + expect_no_error(check_frac_range(c(0, 1))) + + expect_snapshot(error = TRUE, check_frac_range("not numeric")) + expect_snapshot(error = TRUE, check_frac_range(0.5)) + expect_snapshot(error = TRUE, check_frac_range(c(0.1, 0.5, 0.9))) + expect_snapshot(error = TRUE, check_frac_range(c(-0.1, 0.5))) + expect_snapshot(error = TRUE, check_frac_range(c(0.1, 1.5))) + expect_snapshot(error = TRUE, check_frac_range(c(0.1, NA))) +}) + test_that("vctrs-helpers-parameters", { expect_false(dials:::is_parameters(2)) expect_snapshot( From 7c049495b283c60c280168f69b6617e9edad9a02 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 28 Jan 2026 13:41:16 +0000 Subject: [PATCH 7/9] `value_*()`: improve input checks --- R/aaa_values.R | 13 ++++- tests/testthat/_snaps/aaa_values.md | 88 +++++++++++++++++++++++++++++ tests/testthat/test-aaa_values.R | 26 +++++++++ 3 files changed, 126 insertions(+), 1 deletion(-) diff --git a/R/aaa_values.R b/R/aaa_values.R index 5731b766..2462648a 100644 --- a/R/aaa_values.R +++ b/R/aaa_values.R @@ -73,6 +73,7 @@ #' #' @export value_validate <- function(object, values, ..., call = caller_env()) { + check_param(object) res <- switch( object$type, double = , @@ -125,9 +126,14 @@ value_validate_qual <- function(object, values, ..., call = caller_env()) { #' @export #' @rdname value_validate value_seq <- function(object, n, original = TRUE) { + check_param(object) if (inherits(object, "quant_param")) { range_validate(object, object$range, ukn_ok = FALSE) + check_number_whole(n, min = 1) + } else { + check_number_whole(n, min = 1, allow_infinite = TRUE) } + check_bool(original) res <- switch( object$type, @@ -214,6 +220,9 @@ value_seq_qual <- function(object, n) { #' @export #' @rdname value_validate value_sample <- function(object, n, original = TRUE) { + check_param(object) + check_number_whole(n, min = 1) + check_bool(original) if (inherits(object, "quant_param")) { range_validate(object, object$range, ukn_ok = FALSE) } @@ -334,6 +343,7 @@ value_samp_qual <- function(object, n) { #' @export #' @rdname value_validate value_transform <- function(object, values) { + check_param(object) check_for_unknowns(values) if (is.null(object$trans)) { @@ -353,6 +363,7 @@ trans_wrap <- function(x, object) { #' @export #' @rdname value_validate value_inverse <- function(object, values) { + check_param(object) check_for_unknowns(values) if (is.null(object$trans)) { @@ -373,11 +384,11 @@ inv_wrap <- function(x, object) { #' @export #' @rdname value_validate value_set <- function(object, values) { + check_param(object) check_for_unknowns(values) if (length(values) == 0) { cli::cli_abort("{.arg values} must have at least one element.") } - check_param(object) if (inherits(object, "quant_param")) { object <- diff --git a/tests/testthat/_snaps/aaa_values.md b/tests/testthat/_snaps/aaa_values.md index 79a5122b..7c94649f 100644 --- a/tests/testthat/_snaps/aaa_values.md +++ b/tests/testthat/_snaps/aaa_values.md @@ -1,3 +1,91 @@ +# value_validate validates inputs + + Code + value_validate("not a param", 1) + Condition + Error in `value_validate()`: + ! `object` must be a single parameter object, not the string "not a param". + +# value_seq validates inputs + + Code + value_seq("not a param", 5) + Condition + Error in `value_seq()`: + ! `object` must be a single parameter object, not the string "not a param". + +--- + + Code + value_seq(penalty(), "five") + Condition + Error in `value_seq()`: + ! `n` must be a whole number, not the string "five". + +--- + + Code + value_seq(penalty(), -1) + Condition + Error in `value_seq()`: + ! `n` must be a whole number larger than or equal to 1, not the number -1. + +--- + + Code + value_seq(penalty(), 5, original = "yes") + Condition + Error in `value_seq()`: + ! `original` must be `TRUE` or `FALSE`, not the string "yes". + +# value_sample validates inputs + + Code + value_sample("not a param", 5) + Condition + Error in `value_sample()`: + ! `object` must be a single parameter object, not the string "not a param". + +--- + + Code + value_sample(penalty(), "five") + Condition + Error in `value_sample()`: + ! `n` must be a whole number, not the string "five". + +--- + + Code + value_sample(penalty(), -1) + Condition + Error in `value_sample()`: + ! `n` must be a whole number larger than or equal to 1, not the number -1. + +--- + + Code + value_sample(penalty(), 5, original = "yes") + Condition + Error in `value_sample()`: + ! `original` must be `TRUE` or `FALSE`, not the string "yes". + +# value_transform validates inputs + + Code + value_transform("not a param", 1:3) + Condition + Error in `value_transform()`: + ! `object` must be a single parameter object, not the string "not a param". + +# value_inverse validates inputs + + Code + value_inverse("not a param", 1:3) + Condition + Error in `value_inverse()`: + ! `object` must be a single parameter object, not the string "not a param". + # transforms with unknowns Code diff --git a/tests/testthat/test-aaa_values.R b/tests/testthat/test-aaa_values.R index 6a54e1a2..99c49105 100644 --- a/tests/testthat/test-aaa_values.R +++ b/tests/testthat/test-aaa_values.R @@ -1,3 +1,29 @@ +test_that("value_validate validates inputs", { + expect_snapshot(error = TRUE, value_validate("not a param", 1)) +}) + +test_that("value_seq validates inputs", { + expect_snapshot(error = TRUE, value_seq("not a param", 5)) + expect_snapshot(error = TRUE, value_seq(penalty(), "five")) + expect_snapshot(error = TRUE, value_seq(penalty(), -1)) + expect_snapshot(error = TRUE, value_seq(penalty(), 5, original = "yes")) +}) + +test_that("value_sample validates inputs", { + expect_snapshot(error = TRUE, value_sample("not a param", 5)) + expect_snapshot(error = TRUE, value_sample(penalty(), "five")) + expect_snapshot(error = TRUE, value_sample(penalty(), -1)) + expect_snapshot(error = TRUE, value_sample(penalty(), 5, original = "yes")) +}) + +test_that("value_transform validates inputs", { + expect_snapshot(error = TRUE, value_transform("not a param", 1:3)) +}) + +test_that("value_inverse validates inputs", { + expect_snapshot(error = TRUE, value_inverse("not a param", 1:3)) +}) + test_that("transforms with unknowns", { expect_snapshot( error = TRUE, From 455bee06e84248399f6d7fa6e12e8de7b177b0cb Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 28 Jan 2026 13:49:57 +0000 Subject: [PATCH 8/9] `range_*()`: improve input checks --- R/aaa_ranges.R | 32 +++++++++++++-------------- tests/testthat/_snaps/aaa_ranges.md | 34 ++++++++++++++++++++++++++++- tests/testthat/test-aaa_ranges.R | 13 +++++++++++ 3 files changed, 62 insertions(+), 17 deletions(-) diff --git a/R/aaa_ranges.R b/R/aaa_ranges.R index 5b9d1c9d..989b28fd 100644 --- a/R/aaa_ranges.R +++ b/R/aaa_ranges.R @@ -51,6 +51,8 @@ range_validate <- function( ..., call = caller_env() ) { + check_inherits(object, "quant_param", call = call) + check_bool(ukn_ok, call = call) ukn_txt <- if (ukn_ok) { c(i = "{.code Inf} and {.code unknown()} are acceptable values.") } else { @@ -109,6 +111,8 @@ range_validate <- function( #' @export #' @rdname range_validate range_get <- function(object, original = TRUE) { + check_inherits(object, "quant_param") + check_bool(original) if (original & !is.null(object$trans)) { res <- map(object$range, inv_wrap, object) } else { @@ -120,28 +124,24 @@ range_get <- function(object, original = TRUE) { #' @export #' @rdname range_validate range_set <- function(object, range, call = caller_env()) { + check_inherits(object, "quant_param", call = call) + if (length(range) != 2) { cli::cli_abort( "{.arg range} should have two elements, not {length(range)}.", call = call ) } - if (inherits(object, "quant_param")) { - object <- - new_quant_param( - type = object$type, - range = range, - inclusive = object$inclusive, - trans = object$trans, - values = object$values, - label = object$label - ) - } else { - cli::cli_abort( - "{.arg object} should be a {.cls quant_param} object, - not {.obj_type_friendly {object}}.", - call = call + + object <- + new_quant_param( + type = object$type, + range = range, + inclusive = object$inclusive, + trans = object$trans, + values = object$values, + label = object$label ) - } + object } diff --git a/tests/testthat/_snaps/aaa_ranges.md b/tests/testthat/_snaps/aaa_ranges.md index 35179c12..acc4bc8d 100644 --- a/tests/testthat/_snaps/aaa_ranges.md +++ b/tests/testthat/_snaps/aaa_ranges.md @@ -1,3 +1,35 @@ +# `range_validate()` checks inputs + + Code + range_validate("not a param", c(1, 10)) + Condition + Error: + ! `object` must be a object, not the string "not a param". + +--- + + Code + range_validate(penalty(), c(1, 10), ukn_ok = "maybe") + Condition + Error: + ! `ukn_ok` must be `TRUE` or `FALSE`, not the string "maybe". + +# `range_get()` checks inputs + + Code + range_get("not a param") + Condition + Error in `range_get()`: + ! `object` must be a object, not the string "not a param". + +--- + + Code + range_get(penalty(), original = "yes") + Condition + Error in `range_get()`: + ! `original` must be `TRUE` or `FALSE`, not the string "yes". + # setting ranges Code @@ -12,7 +44,7 @@ range_set(activation(), 1:2) Condition Error: - ! `object` should be a object, not a object. + ! `object` must be a object, not a object. --- diff --git a/tests/testthat/test-aaa_ranges.R b/tests/testthat/test-aaa_ranges.R index 150f6483..9e8ef52e 100644 --- a/tests/testthat/test-aaa_ranges.R +++ b/tests/testthat/test-aaa_ranges.R @@ -25,6 +25,19 @@ test_that("transforms", { ) }) +test_that("`range_validate()` checks inputs", { + expect_snapshot(error = TRUE, range_validate("not a param", c(1, 10))) + expect_snapshot( + error = TRUE, + range_validate(penalty(), c(1, 10), ukn_ok = "maybe") + ) +}) + +test_that("`range_get()` checks inputs", { + expect_snapshot(error = TRUE, range_get("not a param")) + expect_snapshot(error = TRUE, range_get(penalty(), original = "yes")) +}) + test_that("setting ranges", { expect_equal( range_set(mtry(), c(5L, 10L))$range, From 8f7ac1f1b339638df36d7c163fadbd02c75d361e Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 28 Jan 2026 13:57:36 +0000 Subject: [PATCH 9/9] `encode_unit()`: improve input checks --- R/encode_unit.R | 6 +++++- tests/testthat/_snaps/encode_unit.md | 10 +++++++++- tests/testthat/test-encode_unit.R | 8 ++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/R/encode_unit.R b/R/encode_unit.R index f4418577..07e95361 100644 --- a/R/encode_unit.R +++ b/R/encode_unit.R @@ -20,12 +20,16 @@ encode_unit <- function(x, value, direction, ...) { #' @export encode_unit.default <- function(x, value, direction, ...) { - cli::cli_abort("{.arg x} should be a dials parameter object.") + cli::cli_abort( + "{.arg x} should be a dials parameter object, + not {.obj_type_friendly x}." + ) } #' @rdname encode_unit #' @export encode_unit.quant_param <- function(x, value, direction, original = TRUE, ...) { + check_bool(original) if (has_unknowns(x)) { cli::cli_abort("The parameter object contains unknowns.") } diff --git a/tests/testthat/_snaps/encode_unit.md b/tests/testthat/_snaps/encode_unit.md index e5268159..71689298 100644 --- a/tests/testthat/_snaps/encode_unit.md +++ b/tests/testthat/_snaps/encode_unit.md @@ -22,13 +22,21 @@ Error in `encode_unit()`: ! `value` should be a numeric vector. +# encode_unit validates original argument + + Code + encode_unit(x, 0.5, direction = "backward", original = "yes") + Condition + Error in `encode_unit()`: + ! `original` must be `TRUE` or `FALSE`, not the string "yes". + # bad args Code encode_unit(2, prune_method()$values, direction = "forward") Condition Error in `encode_unit()`: - ! `x` should be a dials parameter object. + ! `x` should be a dials parameter object, not a string. --- diff --git a/tests/testthat/test-encode_unit.R b/tests/testthat/test-encode_unit.R index 1ef079ed..2acf6a0e 100644 --- a/tests/testthat/test-encode_unit.R +++ b/tests/testthat/test-encode_unit.R @@ -56,6 +56,14 @@ test_that("missing data", { }) +test_that("encode_unit validates original argument", { + x <- mtry(c(2L, 7L)) + expect_snapshot( + error = TRUE, + encode_unit(x, 0.5, direction = "backward", original = "yes") + ) +}) + test_that("bad args", { x <- mtry(c(2L, 7L)) z <- prune_method()