Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions R/arguments.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,32 +41,6 @@ make_x_call <- function(object, target) {
fit_call
}

# Build the engine fit call for a formula-interface engine.
#
# `object` supplies the fit metadata under `object$method$fit` (`args`, an
# optional `data` mapping, and `func` with "fun"/"pkg" entries). `env` is
# expected to carry the user's formula as `env$formula`. The assembled
# arguments are handed to `make_call()` and its result is returned.
make_form_call <- function(object, env = NULL) {
  fit_spec <- object$method$fit
  call_args <- fit_spec$args

  # Names are the symbols used on this side; values are the engine's own
  # argument names. Fall back to the conventional formula/data pair.
  data_map <- fit_spec$data
  if (is.null(data_map)) {
    data_map <- c(formula = "formula", data = "data")
  }

  # Point each engine argument at the matching symbol
  placeholders <- names(data_map)
  for (k in seq_along(data_map)) {
    call_args[[unname(data_map[k])]] <- sym(placeholders[k])
  }

  # Replace the formula placeholder with the actual formula object
  call_args[[unname(data_map["formula"])]] <- env$formula

  make_call(
    fun = fit_spec$func["fun"],
    ns = fit_spec$func["pkg"],
    call_args
  )
}

#' Change arguments of a cluster specification
#'
#' @inheritParams parsnip::set_args
Expand Down
84 changes: 5 additions & 79 deletions R/convert_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@
#' internals of `lm()` (and also see the notes at
#' https://developer.r-project.org/model-fitting-functions.html).
#'
#' `.convert_form_to_x_fit()` and `.convert_x_to_form_fit()` are for when the
#' data are created for modeling. `.convert_form_to_x_fit()` saves both the
#' data objects as well as the objects needed when new data are predicted
#' (e.g. `terms`, etc.).
#' `.convert_form_to_x_fit()` is for when the data are created for modeling.
#' It saves both the data objects as well as the objects needed when new data
#' are predicted (e.g. `terms`, etc.).
#'
#' `.convert_form_to_x_new()` and `.convert_x_to_form_new()` are used when new
#' samples are being predicted and only require the predictors to be
#' available.
#' `.convert_form_to_x_new()` is used when new samples are being predicted and
#' only requires the predictors to be available.
#'
#' @param data A data frame containing all relevant variables (e.g. predictors,
#' case weights, etc).
Expand Down Expand Up @@ -144,68 +142,6 @@ local_one_hot_contrasts <- function(frame = rlang::caller_env()) {
rlang::local_options(contrasts = contrasts, .frame = frame)
}

# ------------------------------------------------------------------------------

# The other direction where we make a formula from the data
# objects

# TODO slots for other roles
#' @param weights A numeric vector containing the weights.
#' @inheritParams fit.cluster_spec
#' @inheritParams .convert_form_to_x_fit
#' @rdname convert_helpers
#' @keywords internal
.convert_x_to_form_fit <- function(x, weights = NULL, remove_intercept = TRUE) {
  # Turns a predictor matrix / data frame into the pieces a formula-interface
  # engine needs. Returns a list with:
  #   formula - one-sided formula built by `make_formula()`
  #   data    - `x` as a data frame (intercept column dropped if requested)
  #   weights - the validated `weights`, passed through unchanged
  #   x_var   - predictor column names
  if (is.vector(x)) {
    cli::cli_abort("{.arg x} cannot be a vector.")
  }

  # Only filter by name when there are names to filter on: comparing `NULL`
  # column names yields `logical(0)`, which would drop every column.
  if (remove_intercept && !is.null(colnames(x))) {
    x <- x[, colnames(x) != "(Intercept)", drop = FALSE]
  }

  # `as.data.frame()` keeps the row names of a matrix, so they need no
  # separate save/restore step.
  if (!is.data.frame(x)) {
    x <- as.data.frame(x)
  }

  x_var <- names(x)
  form <- make_formula(x_var)

  # There is no outcome in this unsupervised setting, so nothing is bound to
  # `x`. (The earlier `bind_cols(x, y)` referenced an undefined `y`.)

  if (!is.null(weights)) {
    if (!is.numeric(weights)) {
      cli::cli_abort("The {.arg weights} must be a numeric vector.")
    }
    if (length(weights) != nrow(x)) {
      cli::cli_abort("{.arg weights} should have {nrow(x)} elements.")
    }
  }

  list(
    formula = form,
    data = x,
    weights = weights,
    x_var = x_var
  )
}

# Build a one-sided formula for the predictors named in `x`.
#
# With `short = TRUE` (the default) the result is `~.` and `x` is not used;
# otherwise the names in `x` are joined with `+` on the right-hand side.
make_formula <- function(x, short = TRUE) {
  y_part <- "~"
  if (short) {
    form_text <- paste0(y_part, ".")
  } else {
    form_text <- paste0(y_part, paste0(x, collapse = "+"))
  }
  as.formula(form_text)
}

#' @param object An object of class [`cluster_fit`].
#' @inheritParams predict.cluster_fit
#' @rdname convert_helpers
Expand Down Expand Up @@ -262,13 +198,3 @@ make_formula <- function(x, short = TRUE) {
}
list(x = new_data, offset = offset)
}

#' @rdname convert_helpers
#' @keywords internal
.convert_x_to_form_new <- function(object, new_data) {
  # Keep only the predictor columns recorded in `object$x_var`, in that order
  predictors <- new_data[, object$x_var, drop = FALSE]

  # Data frames pass straight through; anything else is coerced to one
  if (is.data.frame(predictors)) {
    return(predictors)
  }
  as.data.frame(predictors)
}
2 changes: 2 additions & 0 deletions R/engine_docs.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# nocov start
# https://github.com/tidymodels/parsnip/blob/main/R/engine_docs.R

#' Knit engine-specific documentation
Expand Down Expand Up @@ -64,3 +65,4 @@ list_md_problems <- function() {

map(md_files, get_errors) |> vctrs::vec_rbind()
}
# nocov end
109 changes: 34 additions & 75 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,48 +120,29 @@ fit.cluster_spec <- function(

eval_env$data <- data
eval_env$formula <- formula
fit_interface <-
check_interface(eval_env$formula, eval_env$data, cl, object)
check_interface(eval_env$formula, eval_env$data, cl, object)

# populate `method` with the details for this model type
object <- add_methods(object, engine = object$engine)

check_installs(object)

interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_")

# Now call the wrappers that transition between the interface
# called here ("fit" interface) that will direct traffic to
# what the underlying model uses. For example, if a formula is
# used here, `fit_interface_formula` will determine if a
# translation has to be made if the model interface is x/y.
res <-
switch(
interfaces,
# homogeneous combinations:
formula_formula = form_form(
object = object,
control = control,
env = eval_env
),

# heterogeneous combinations
formula_matrix = form_x(
object = object,
control = control,
env = eval_env,
target = object$method$fit$interface,
...
),
formula_data.frame = form_x(
object = object,
control = control,
env = eval_env,
target = object$method$fit$interface,
...
),
cli::cli_abort("{interfaces} is unknown.")
interface <- object$method$fit$interface
if (!interface %in% c("matrix", "data.frame")) {
# nocov start
cli::cli_abort(
"Column interface {.val {interface}} is not supported.",
.internal = TRUE
)
} # nocov end

res <- form_x(
object = object,
control = control,
env = eval_env,
target = interface,
...
)
model_classes <- class(res$fit)
class(res) <- c(paste0("_", model_classes[1]), "cluster_fit")
res <- modelenv::new_unsupervised_fit(res)
Expand All @@ -177,7 +158,7 @@ check_interface <- function(formula, data, cl, model) {
if (form_interface) {
return("formula")
}
cli::cli_abort("Error when checking the interface.")
cli::cli_abort("Error when checking the interface.", .internal = TRUE) # nocov
}

inher <- function(x, cls, cl) {
Expand Down Expand Up @@ -249,51 +230,29 @@ fit_xy.cluster_spec <-
cl <- match.call(expand.dots = TRUE)
eval_env <- rlang::env()
eval_env$x <- x
fit_interface <- check_x_interface(eval_env$x, cl, object)
check_x_interface(eval_env$x, cl, object)

# populate `method` with the details for this model type
object <- add_methods(object, engine = object$engine)

check_installs(object)

interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_")

# Now call the wrappers that transition between the interface
# called here ("fit" interface) that will direct traffic to
# what the underlying model uses. For example, if a formula is
# used here, `fit_interface_formula` will determine if a
# translation has to be made if the model interface is x/y.
res <-
switch(
interfaces,
# homogeneous combinations:
matrix_matrix = ,
data.frame_matrix = x_x(
object = object,
env = eval_env,
control = control,
target = "matrix",
...
),
data.frame_data.frame = ,
matrix_data.frame = x_x(
object = object,
env = eval_env,
control = control,
target = "data.frame",
...
),

# heterogeneous combinations
matrix_formula = ,
data.frame_formula = x_form(
object = object,
env = eval_env,
control = control,
...
),
cli::cli_abort("{interfaces} is unknown.")
interface <- object$method$fit$interface
if (!interface %in% c("matrix", "data.frame")) {
# nocov start
cli::cli_abort(
"Column interface {.val {interface}} is not supported.",
.internal = TRUE
)
} # nocov end

res <- x_x(
object = object,
env = eval_env,
control = control,
target = interface,
...
)
model_classes <- class(res$fit)
class(res) <- c(paste0("_", model_classes[1]), "cluster_fit")
res
Expand Down Expand Up @@ -328,7 +287,7 @@ check_x_interface <- function(x, cl, model) {
if (df_interface) {
return("data.frame")
}
cli::cli_abort("Error when checking the interface")
cli::cli_abort("Error when checking the interface", .internal = TRUE) # nocov
}

allow_sparse <- function(x) {
Expand Down
65 changes: 0 additions & 65 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
@@ -1,43 +1,5 @@
# https://github.com/tidymodels/parsnip/blob/main/R/fit_helpers.R

# Fit an engine that itself takes a formula, given a formula from the user.
#
# Returns a list with `spec` (the checked and translated specification),
# `fit` (whatever `eval_mod()` returns for the engine call) and `elapsed`
# (a `system.time()` result when `control$verbosity > 1`, otherwise a
# placeholder list holding `NA_real_`).
form_form <- function(object, control, env, ...) {
  # Evaluate quoted arguments once up front so they are checked before fitting
  checked <- check_args(object)

  # Swap in the engine's own argument syntax
  spec <- translate_tidyclust(checked, engine = checked$engine)

  fit_call <- make_form_call(spec, env = env)

  res <- list(spec = spec)

  if (control$verbosity > 1L) {
    # Only pay for timing when the user asked for verbose output
    timing <- system.time(
      model <- eval_mod(
        fit_call,
        capture = control$verbosity == 0,
        catch = control$catch,
        env = env,
        ...
      ),
      gcFirst = FALSE
    )
  } else {
    model <- eval_mod(
      fit_call,
      capture = control$verbosity == 0,
      catch = control$catch,
      env = env,
      ...
    )
    timing <- list(elapsed = NA_real_)
  }

  res$fit <- model
  res$elapsed <- timing
  res
}

form_x <- function(object, control, env, target = "none", ...) {
encoding_info <-
modelenv::get_encoding(class(object)[1]) |>
Expand Down Expand Up @@ -116,30 +78,3 @@ x_x <- function(object, env, control, target = "none", y = NULL, ...) {
res$elapsed <- elapsed
res
}

# Fit a formula-interface engine when the user supplied `x` rather than a
# formula: the predictors in `env$x` are converted to a formula/data pair,
# stored back on `env`, and the fit is delegated to `form_form()`.
x_form <- function(object, env, control, ...) {
  # Encoding settings registered for this model class, mode and engine
  encoding <- dplyr::filter(
    modelenv::get_encoding(class(object)[1]),
    mode == object$mode,
    engine == object$engine
  )
  drop_intercept <- dplyr::pull(encoding, remove_intercept)

  converted <- .convert_x_to_form_fit(
    x = env$x,
    weights = NULL,
    remove_intercept = drop_intercept
  )
  env$formula <- converted$formula
  env$data <- converted$data

  # which terms etc goes in the preproc slot here?
  fitted <- form_form(
    object = object,
    env = env,
    control = control,
    ...
  )
  fitted$preproc <- converted["x_var"]
  fitted
}
Loading