diff --git a/DESCRIPTION b/DESCRIPTION index 1bd510e7..0308400d 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -47,6 +47,7 @@ Collate: 'accelerator.R' 'utils.R' 'callbacks.R' + 'callback-validation-check.R' 'callbacks-interrupt.R' 'callbacks-profile.R' 'context.R' diff --git a/NAMESPACE b/NAMESPACE index 25471b33..f003842d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -19,6 +19,7 @@ export(luz_callback_model_checkpoint) export(luz_callback_profile) export(luz_callback_progress) export(luz_callback_train_valid) +export(luz_callback_validation_check) export(luz_load) export(luz_load_model_weights) export(luz_metric) diff --git a/R/callback-validation-check.R b/R/callback-validation-check.R new file mode 100644 index 00000000..c788975f --- /dev/null +++ b/R/callback-validation-check.R @@ -0,0 +1,74 @@ +#' @include callbacks.R +NULL + +#' Validation Check +#' +#' Check validation loop before fitting model. +#' +#' @param batches Number of validation batches to check. Default is 2. +#' +#' @note Usually the training loop is much longer than the validation +#' loop and issues with the validation loop aren't encountered until after +#' a long training runtime. This callback runs the validation loop first on +#' `batches` number of batches and then proceeds onto the standard +#' training process. +#' +#' @note Printing can be disabled by passing `verbose = FALSE` to +#' [fit.luz_module_generator()]. +#' +#' @family luz_callbacks +#' +#' @returns +#' A `luz_callback`. 
+#' +#' @export +luz_callback_validation_check <- luz_callback( + "validation_check_callback", + initialize = function(batches = 2) { + if (!rlang::is_scalar_integerish(batches)) { + rlang::abort("`batches` must be a single integer value.") + } + self$batches <- batches + }, + on_fit_begin = function() { + if (is.null(ctx$valid_data)) return() + if (self$batches <= 0) return() + + ctx$model$eval() + ctx$training <- FALSE + + self$initialize_progress_bar() + + i <- 0 + torch::with_no_grad({ + coro::loop(for (batch in ctx$valid_data) { + self$validate_one_batch(batch) + self$tick_progress_bar(self$loss) + i <- i + 1 + if (i >= self$batches) break() + }) + }) + }, + validate_one_batch = function(batch) { + input <- list(batch[[1]]) + target <- batch[[2]] + pred <- do.call(ctx$model, input) + self$loss <- ctx$model$loss(pred, target) + }, + initialize_progress_bar = function() { + format <- "Validation check: :current/:total [:bar] - Loss: :loss" + self$pb <- progress::progress_bar$new( + force = getOption("luz.force_progress_bar", FALSE), + show_after = 0, + format = format, + total = self$batches, + clear = FALSE + ) + }, + tick_progress_bar = function(token) { + if (ctx$verbose) { + loss <- format(round(as.numeric(token), digits = 4), nsmall = 4) + self$pb$tick(tokens = list(loss = loss)) + } + } +) diff --git a/R/module.R b/R/module.R index e2f2d4e4..93b7c741 100644 --- a/R/module.R +++ b/R/module.R @@ -413,8 +413,9 @@ clean_context <- function(ctx) { "pred", "opt", "opt_name", - "data", "handlers", + "data", + "train_data", "valid_data", "loss", "input", diff --git a/man/ctx.Rd b/man/ctx.Rd index 298d7bb8..28f03b58 100644 --- a/man/ctx.Rd +++ b/man/ctx.Rd @@ -20,7 +20,8 @@ could potentially modify these attributes or add new ones.\tabular{ll}{ \code{data} \tab Current in use dataloader. When training it’s \code{ctx$train_data}, when doing validation its \code{ctx$valid_data}. It can also be the prediction dataset when in \code{predict}. 
\cr \code{train_data} \tab Dataloader passed to the \code{data} argument in \code{fit}. Modified to yield data in the selected device. \cr \code{valid_data} \tab Dataloader passed to the \code{valid_data} argument in \code{fit}. Modified to yield data in the selected device. \cr - \code{epochs} \tab Total number of epochs the model will be trained on. \cr + \code{min_epochs} \tab Minimum number of epochs the model will be trained for. \cr + \code{max_epochs} \tab Maximum number of epochs the model will be trained for. \cr \code{epoch} \tab Current training epoch. \cr \code{iter} \tab Current training iteration. It’s reset every epoch and when going from training to validation. \cr \code{training} \tab Whether the model is in training or validation mode. See also \code{help("luz_callback_train_valid")} \cr diff --git a/man/luz_callback.Rd b/man/luz_callback.Rd index ffdb54b4..8d270935 100644 --- a/man/luz_callback.Rd +++ b/man/luz_callback.Rd @@ -142,6 +142,7 @@ Other luz_callbacks: \code{\link{luz_callback_model_checkpoint}()}, \code{\link{luz_callback_profile}()}, \code{\link{luz_callback_progress}()}, -\code{\link{luz_callback_train_valid}()} +\code{\link{luz_callback_train_valid}()}, +\code{\link{luz_callback_validation_check}()} } \concept{luz_callbacks} diff --git a/man/luz_callback_csv_logger.Rd b/man/luz_callback_csv_logger.Rd index 941d5af1..eb7759ea 100644 --- a/man/luz_callback_csv_logger.Rd +++ b/man/luz_callback_csv_logger.Rd @@ -23,6 +23,7 @@ Other luz_callbacks: \code{\link{luz_callback_profile}()}, \code{\link{luz_callback_progress}()}, \code{\link{luz_callback_train_valid}()}, +\code{\link{luz_callback_validation_check}()}, \code{\link{luz_callback}()} } \concept{luz_callbacks} diff --git a/man/luz_callback_early_stopping.Rd b/man/luz_callback_early_stopping.Rd index 9c11e4b9..afdf4a70 100644 --- a/man/luz_callback_early_stopping.Rd +++ b/man/luz_callback_early_stopping.Rd @@ -56,6 +56,7 @@ Other luz_callbacks: 
\code{\link{luz_callback_profile}()}, \code{\link{luz_callback_progress}()}, \code{\link{luz_callback_train_valid}()}, +\code{\link{luz_callback_validation_check}()}, \code{\link{luz_callback}()} } \concept{luz_callbacks} diff --git a/man/luz_callback_interrupt.Rd b/man/luz_callback_interrupt.Rd index 5d7c5369..96169847 100644 --- a/man/luz_callback_interrupt.Rd +++ b/man/luz_callback_interrupt.Rd @@ -32,6 +32,7 @@ Other luz_callbacks: \code{\link{luz_callback_profile}()}, \code{\link{luz_callback_progress}()}, \code{\link{luz_callback_train_valid}()}, +\code{\link{luz_callback_validation_check}()}, \code{\link{luz_callback}()} } \concept{luz_callbacks} diff --git a/man/luz_callback_lr_scheduler.Rd b/man/luz_callback_lr_scheduler.Rd index 7ac57e29..5e1e9933 100644 --- a/man/luz_callback_lr_scheduler.Rd +++ b/man/luz_callback_lr_scheduler.Rd @@ -46,6 +46,7 @@ Other luz_callbacks: \code{\link{luz_callback_profile}()}, \code{\link{luz_callback_progress}()}, \code{\link{luz_callback_train_valid}()}, +\code{\link{luz_callback_validation_check}()}, \code{\link{luz_callback}()} } \concept{luz_callbacks} diff --git a/man/luz_callback_metrics.Rd b/man/luz_callback_metrics.Rd index 2a70cb77..c5f87924 100644 --- a/man/luz_callback_metrics.Rd +++ b/man/luz_callback_metrics.Rd @@ -36,6 +36,7 @@ Other luz_callbacks: \code{\link{luz_callback_profile}()}, \code{\link{luz_callback_progress}()}, \code{\link{luz_callback_train_valid}()}, +\code{\link{luz_callback_validation_check}()}, \code{\link{luz_callback}()} } \concept{luz_callbacks} diff --git a/man/luz_callback_model_checkpoint.Rd b/man/luz_callback_model_checkpoint.Rd index e0945fde..4571a4c1 100644 --- a/man/luz_callback_model_checkpoint.Rd +++ b/man/luz_callback_model_checkpoint.Rd @@ -59,6 +59,7 @@ Other luz_callbacks: \code{\link{luz_callback_profile}()}, \code{\link{luz_callback_progress}()}, \code{\link{luz_callback_train_valid}()}, +\code{\link{luz_callback_validation_check}()}, \code{\link{luz_callback}()} } 
\concept{luz_callbacks} diff --git a/man/luz_callback_profile.Rd b/man/luz_callback_profile.Rd index 04cf8119..765ec9ce 100644 --- a/man/luz_callback_profile.Rd +++ b/man/luz_callback_profile.Rd @@ -42,6 +42,7 @@ Other luz_callbacks: \code{\link{luz_callback_model_checkpoint}()}, \code{\link{luz_callback_progress}()}, \code{\link{luz_callback_train_valid}()}, +\code{\link{luz_callback_validation_check}()}, \code{\link{luz_callback}()} } \concept{luz_callbacks} diff --git a/man/luz_callback_progress.Rd b/man/luz_callback_progress.Rd index b0845df7..155a0565 100644 --- a/man/luz_callback_progress.Rd +++ b/man/luz_callback_progress.Rd @@ -28,6 +28,7 @@ Other luz_callbacks: \code{\link{luz_callback_model_checkpoint}()}, \code{\link{luz_callback_profile}()}, \code{\link{luz_callback_train_valid}()}, +\code{\link{luz_callback_validation_check}()}, \code{\link{luz_callback}()} } \concept{luz_callbacks} diff --git a/man/luz_callback_train_valid.Rd b/man/luz_callback_train_valid.Rd index f2b09439..226ed1c0 100644 --- a/man/luz_callback_train_valid.Rd +++ b/man/luz_callback_train_valid.Rd @@ -37,6 +37,7 @@ Other luz_callbacks: \code{\link{luz_callback_model_checkpoint}()}, \code{\link{luz_callback_profile}()}, \code{\link{luz_callback_progress}()}, +\code{\link{luz_callback_validation_check}()}, \code{\link{luz_callback}()} } \concept{luz_callbacks} diff --git a/man/luz_callback_validation_check.Rd b/man/luz_callback_validation_check.Rd new file mode 100644 index 00000000..8d3c5377 --- /dev/null +++ b/man/luz_callback_validation_check.Rd @@ -0,0 +1,41 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/callback-validation-check.R +\name{luz_callback_validation_check} +\alias{luz_callback_validation_check} +\title{Validation Check} +\usage{ +luz_callback_validation_check(batches = 2) +} +\arguments{ +\item{batches}{Number of validation batches to check. Default is 2.} +} +\value{ +A \code{luz_callback}. 
+} +\description{ +Check validation loop before fitting model. +} +\note{ +Usually the training loop is much longer than the validation +loop and issues with the validation loop aren't encountered until after +a long training runtime. This callback runs the validation loop first on +\code{batches} number of batches and then proceeds onto the standard +training process. + +Printing can be disabled by passing \code{verbose = FALSE} to +\code{\link[=fit.luz_module_generator]{fit.luz_module_generator()}}. +} +\seealso{ +Other luz_callbacks: +\code{\link{luz_callback_csv_logger}()}, +\code{\link{luz_callback_early_stopping}()}, +\code{\link{luz_callback_interrupt}()}, +\code{\link{luz_callback_lr_scheduler}()}, +\code{\link{luz_callback_metrics}()}, +\code{\link{luz_callback_model_checkpoint}()}, +\code{\link{luz_callback_profile}()}, +\code{\link{luz_callback_progress}()}, +\code{\link{luz_callback_train_valid}()}, +\code{\link{luz_callback}()} +} +\concept{luz_callbacks}