Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ export(activation)
export(activation_2)
export(adjust_deg_free)
export(all_neighbors)
export(average_before_softmax)
export(balance_probabilities)
export(batch_size)
export(buffer)
export(cal_method_class)
Expand Down Expand Up @@ -119,6 +121,7 @@ export(no_global_pruning)
export(num_breaks)
export(num_clusters)
export(num_comp)
export(num_estimators)
export(num_hash)
export(num_knots)
export(num_leaves)
Expand Down Expand Up @@ -172,6 +175,7 @@ export(shrinkage_variance)
export(signed_hash)
export(significance_threshold)
export(smoothness)
export(softmax_temperature)
export(spline_degree)
export(splitting_rule)
export(stop_iter)
Expand All @@ -182,6 +186,7 @@ export(svm_margin)
export(target_weight)
export(threshold)
export(token)
export(training_set_limit)
export(tree_depth)
export(trees)
export(trim_amount)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

* A bug was fixed where some space-filling designs did not respect the `original` argument (#409).

* Parameters were added for the `tab_pfn` model: `num_estimators()`, `softmax_temperature()`, `balance_probabilities()`, `average_before_softmax()`, and `training_set_limit()`.

# dials 1.4.2

* `prop_terms()` is a new parameter object used for recipes that do supervised feature selection (#395).
Expand Down
2 changes: 1 addition & 1 deletion R/param_schedulers.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#' Parameters for neural network learning rate schedulers
#
#'
#' These parameters are used for constructing neural network models.
#'
#' @inheritParams Laplace
Expand Down
70 changes: 70 additions & 0 deletions R/param_tab_pfn.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#' Parameters for TabPFN models
#'
#' These parameters are used for constructing Prior data fitted network (TabPFN)
#' models.
#'
#' @inheritParams Laplace
#' @inheritParams select_features
#'
#' @details
#' These parameters are often used with TabPFN models via `parsnip::tab_pfn()`.
#' @name tab-pfn-param
#' @export
num_estimators <- function(range = c(1, 25), trans = NULL) {
new_quant_param(
type = "integer",
range = range,
inclusive = c(TRUE, TRUE),
trans = trans,
label = c(num_estimators = "# Estimators"),
finalize = NULL
)
}

#' @rdname tab-pfn-param
#' @export
softmax_temperature <- function(range = c(0, 10), trans = NULL) {
new_quant_param(
type = "double",
range = range,
inclusive = c(FALSE, TRUE),
trans = trans,
label = c(softmax_temperature = "Softmax Temperature"),
finalize = NULL
)
}

#' @rdname tab-pfn-param
#' @export
balance_probabilities <- function(values = c(TRUE, FALSE)) {
new_qual_param(
type = "logical",
values = values,
label = c(balance_probabilities = "Balance Probabilities?"),
finalize = NULL
)
}

#' @rdname tab-pfn-param
#' @export
average_before_softmax <- function(values = c(TRUE, FALSE)) {
new_qual_param(
type = "logical",
values = values,
label = c(average_before_softmax = "Average Before Softmax?"),
finalize = NULL
)
}

#' @rdname tab-pfn-param
#' @export
training_set_limit <- function(range = c(2L, 10000L), trans = NULL) {
new_quant_param(
type = "integer",
range = range,
inclusive = c(TRUE, TRUE),
trans = trans,
label = c(training_set_limit = "Training Set Size"),
finalize = NULL
)
}
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ reference:
- neighbors
- num_clusters
- num_comp
- num_estimators
- num_knots
- penalty
- predictor_prop
Expand Down
4 changes: 1 addition & 3 deletions man/scheduler-param.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 40 additions & 0 deletions man/tab-pfn-param.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions tests/testthat/test-params.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,15 @@ test_that("param ranges", {
expect_equal(mtry_prop(c(.1, .2))$range, list(lower = .1, upper = .2))
expect_equal(dropout(c(.1, .2))$range, list(lower = .1, upper = .2))
expect_equal(prop_terms(c(.1, .2))$range, list(lower = .1, upper = .2))
expect_equal(num_estimators(c(1L, 10L))$range, list(lower = 1L, upper = 10L))
expect_equal(
softmax_temperature(c(0.1, 2.0))$range,
list(lower = 0.1, upper = 2.0)
)
expect_equal(
training_set_limit(c(2L, 10L))$range,
list(lower = 2L, upper = 10L)
)
})


Expand Down Expand Up @@ -179,4 +188,6 @@ test_that("param values", {
expect_equal(all_neighbors(TRUE)$values, TRUE)
expect_equal(cal_method_class()$values, values_cal_cls)
expect_equal(cal_method_reg()$values, values_cal_reg)
expect_equal(balance_probabilities(TRUE)$values, TRUE)
expect_equal(average_before_softmax(TRUE)$values, TRUE)
})