diff --git a/NEWS.md b/NEWS.md index 5ae5e18b..2c404548 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,8 @@ * `prop_terms()` is a new parameter object used for recipes that do supervised feature selection (#395). +* `batch_size()` now has a specific default parameter range instead of an unknown default range. `get_batch_sizes()` is deprecated (#398). + # dials 1.4.1 * Two new parameters, `cal_method_class()` and `cal_method_reg(),` to control which method is used to calibrate model predictions (#383). diff --git a/R/finalize.R b/R/finalize.R index e1fa613f..347ef9d5 100644 --- a/R/finalize.R +++ b/R/finalize.R @@ -273,9 +273,16 @@ get_rbf_range <- function(object, x, seed = sample.int(10^5, 1), ...) { range_set(object, rng) } +#' Get batch sizes +#' +#' `r lifecycle::badge("deprecated")` +#' +#' @inheritParams finalize +#' @keywords internal #' @export -#' @rdname finalize get_batch_sizes <- function(object, x, frac = c(1 / 10, 1 / 3), ...) { + lifecycle::deprecate_warn("1.4.2", "get_batch_sizes()") + rngs <- range_get(object, original = FALSE) if (!is_unknown(rngs$lower) & !is_unknown(rngs$upper)) { return(object) diff --git a/R/param_network.R b/R/param_network.R index 64100445..5d37b01c 100644 --- a/R/param_network.R +++ b/R/param_network.R @@ -67,7 +67,7 @@ hidden_units_2 <- function(range = c(1L, 10L), trans = NULL) { #' @export #' @rdname dropout batch_size <- function( - range = c(unknown(), unknown()), + range = c(2L, 7L), trans = transform_log2() ) { new_quant_param( @@ -76,6 +76,6 @@ batch_size <- function( inclusive = c(TRUE, TRUE), trans = trans, label = c(batch_size = "Batch Size"), - finalize = get_batch_sizes + finalize = NULL ) } diff --git a/man/dropout.Rd b/man/dropout.Rd index fcb14d83..9026d276 100644 --- a/man/dropout.Rd +++ b/man/dropout.Rd @@ -16,7 +16,7 @@ hidden_units(range = c(1L, 10L), trans = NULL) hidden_units_2(range = c(1L, 10L), trans = NULL) -batch_size(range = c(unknown(), unknown()), trans = transform_log2()) +batch_size(range = c(2L, 7L), trans = transform_log2()) } \arguments{ \item{range}{A two-element vector holding the \emph{defaults} for the smallest and diff --git a/man/finalize.Rd b/man/finalize.Rd index c7d6a442..3958c318 100644 --- a/man/finalize.Rd +++ b/man/finalize.Rd @@ -13,7 +13,6 @@ \alias{get_n_frac_range} \alias{get_n} \alias{get_rbf_range} -\alias{get_batch_sizes} \title{Functions to finalize data-specific parameter ranges} \usage{ finalize(object, ...) @@ -39,8 +38,6 @@ get_n_frac_range(object, x, log_vals = FALSE, frac = c(1/10, 5/10), ...) get_n(object, x, log_vals = FALSE, ...) get_rbf_range(object, x, seed = sample.int(10^5, 1), ...) - -get_batch_sizes(object, x, frac = c(1/10, 1/3), ...) } \arguments{ \item{object}{A \code{param} object or a list of \code{param} objects.} diff --git a/man/get_batch_sizes.Rd b/man/get_batch_sizes.Rd new file mode 100644 index 00000000..ceeaf181 --- /dev/null +++ b/man/get_batch_sizes.Rd @@ -0,0 +1,26 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/finalize.R +\name{get_batch_sizes} +\alias{get_batch_sizes} +\title{Get batch sizes} +\usage{ +get_batch_sizes(object, x, frac = c(1/10, 1/3), ...) +} +\arguments{ +\item{object}{A \code{param} object or a list of \code{param} objects.} + +\item{x}{The predictor data. In some cases (see below) this should only +include numeric data.} + +\item{frac}{A double for the fraction of the data to be used for the upper +bound. For \code{get_n_frac_range()} and \code{get_batch_sizes()}, a vector of two +fractional values are required.} + +\item{...}{Other arguments to pass to the underlying parameter +finalizer functions. For example, for \code{get_rbf_range()}, the dots are passed +along to \code{\link[kernlab:sigest]{kernlab::sigest()}}.} +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} +} +\keyword{internal} diff --git a/tests/testthat/_snaps/finalize.md b/tests/testthat/_snaps/finalize.md index d8ed570f..11186691 100644 --- a/tests/testthat/_snaps/finalize.md +++ b/tests/testthat/_snaps/finalize.md @@ -46,6 +46,14 @@ Error in `get_n_frac()`: ! Cannot determine number of columns. Is `x` a 2D data object? +# `get_batch_size() is deprecated + + Code + bsizes <- get_batch_sizes(batch_size(), iris, frac = c(0.3, 0.7)) + Condition + Warning: + `get_batch_sizes()` was deprecated in dials 1.4.2. + # estimate sigma Code diff --git a/tests/testthat/test-finalize.R b/tests/testthat/test-finalize.R index a221bcdd..790df1c9 100644 --- a/tests/testthat/test-finalize.R +++ b/tests/testthat/test-finalize.R @@ -54,12 +54,33 @@ test_that("estimate rows", { list(lower = 16, upper = 37) ) + expect_equal(get_n_frac(mtry(c(1, 2)), mtcars), mtry(c(1, 2))) +}) + +test_that("`get_batch_size() is deprecated", { + expect_snapshot( + bsizes <- get_batch_sizes(batch_size(), iris, frac = c(.3, .7)) + ) +}) + +test_that("`get_batch_size() works", { + withr::local_options(lifecycle_verbosity = "quiet") + mock_batch_size_with_unknown <- new_quant_param( + type = "integer", + range = c(unknown(), unknown()), + inclusive = c(TRUE, TRUE), + trans = transform_log2(), + label = c(batch_size = "Batch Size"), + finalize = get_batch_sizes + ) expect_equal( - get_batch_sizes(batch_size(), iris, frac = c(.3, .7))$range, + get_batch_sizes( + mock_batch_size_with_unknown, + iris, + frac = c(.3, .7) + )$range, list(lower = log2(45), upper = log2(105)) ) - - expect_equal(get_n_frac(mtry(c(1, 2)), mtcars), mtry(c(1, 2))) })