Rewrite of training, callbacks, losses, optimizers and batchers for more flexibility.#72
Draft
Conversation
The `fit` function is modularized. The main loop is performed within the `_fit_core` function, which knows nothing about the data, optimizer, model, or loss; it only performs the update. The original API of `fit` is provided as a wrapper, which exposes some parameters of the training components to the user, e.g., the batch axis. Different `Protocol`s are implemented for the loss, the updater, and the batcher, but none of them create new layers of abstraction: they are not new classes, but merely serve as templates for the different functions required for training. Note: Callbacks have not been implemented in this approach yet.
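To illustrate the "Protocols as templates" idea, here is a minimal sketch of what such protocols could look like. The names `LossFn`, `Updater`, and `Batcher` and their exact signatures are assumptions for illustration; the protocols in the module may differ.

```python
from typing import Any, Iterator, Protocol


class LossFn(Protocol):
    """Template for a loss function: maps (params, batch) to a scalar."""

    def __call__(self, params: Any, batch: Any) -> float: ...


class Updater(Protocol):
    """Template for one update step: consumes the current parameters and a
    batch, and returns the updated parameters together with the step loss."""

    def __call__(self, params: Any, batch: Any) -> tuple[Any, float]: ...


class Batcher(Protocol):
    """Template for a batch generator: yields one batch per training step."""

    def __call__(self, data: Any) -> Iterator[Any]: ...
```

Because these are `Protocol`s, any plain function or generator with a matching signature satisfies them structurally; no subclassing is required, which is what keeps the abstraction flat.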
An example was added at the bottom of the module to showcase the flexibility of the `fit_core` method for building custom trainings; a sketch of what such a custom training could look like is shown below.
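The following is a hedged sketch of a custom training built from plain functions, in the spirit of the example described above. The call signature `fit_core(batcher, updater, init_params)` at the end is an assumption, not the module's actual API.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    """Mean-squared-error loss for a linear model."""
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def sgd_updater(params, batch):
    """One SGD step: the updater needs only params and a batch."""
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-2 * g, params, grads)
    return params, loss


def simple_batcher(data, batch_size=32):
    """Plain generator yielding consecutive slices of the data."""
    x, y = data
    for start in range(0, x.shape[0], batch_size):
        yield x[start : start + batch_size], y[start : start + batch_size]


# Hypothetical call, assuming a signature like fit_core(batcher, updater, params):
#   params = fit_core(simple_batcher((x, y)), sgd_updater, params)
```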
The protocol for `BatchGenerator`s has been changed to allow a `TrainingState` to be sent to the generator. This enables stateful batching, where, e.g., the step of the model can be used as a seed for random number generation. Additionally, the task of flattening and unflattening the model was moved to the `TrainingState`, making the `fit_core` function even more minimal.
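A sketch of such a stateful batcher, assuming the training loop drives the generator via `.send(state)` and that the state exposes a `step` attribute; both are assumptions about this PR's API.

```python
import jax


def stateful_batcher(x, y, batch_size):
    """Generator-based batcher: each .send(state) yields a batch whose
    subsampling is seeded by the current training step."""
    state = yield  # prime the generator; receive the first TrainingState
    while True:
        key = jax.random.PRNGKey(state.step)  # the step doubles as the seed
        idx = jax.random.choice(key, x.shape[0], (batch_size,), replace=False)
        state = yield (x[idx], y[idx])
```

Usage would follow the standard generator protocol: `gen = stateful_batcher(x, y, 32)`, one priming `next(gen)`, then `batch = gen.send(state)` inside the loop. Seeding from the step makes the batch sequence fully reproducible without threading a PRNG key through the loop.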
The `EvaluationContext` takes care of the lazy loss evaluation. It receives the training state as input and stores the loss and validation loss in a cache. The cache is updated whenever the training step changes. The `TrainingContext` is a simple wrapper around a `TrainingState` and an `EvaluationContext`. It serves as the main input to every `Callback`. This new implementation nicely separates the plain training state from the data and the evaluation.
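A minimal sketch of the step-keyed lazy evaluation described here; the attribute and method names are assumptions, and the PR's `EvaluationContext` may differ in detail.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class EvaluationContext:
    loss_fn: Callable[[Any], float]
    val_loss_fn: Callable[[Any], float]
    _cache: dict = field(default_factory=dict)

    def losses(self, state) -> tuple[float, float]:
        # Recompute only when the training step has advanced; otherwise
        # return the cached (loss, val_loss) pair unchanged.
        if self._cache.get("step") != state.step:
            self._cache = {
                "step": state.step,
                "loss": self.loss_fn(state),
                "val_loss": self.val_loss_fn(state),
            }
        return self._cache["loss"], self._cache["val_loss"]
```

The point of the cache is that several callbacks can query the losses within the same step while paying for at most one evaluation.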
The dependency on the training step, which was introduced as a test, has been removed.
Drenderer reviewed Sep 15, 2025
Comment on lines +137 to +138:

```python
if init_opt_state is None
else init_opt_state,
```
The main change is a new `fit_core` function, which implements a training loop that depends only on a batcher, an updater, and a `TrainingContext`.

The newly added `TrainingContext` includes the functionality of the former `CallbackArgs`, but is further subdivided internally to separate data and losses from model and optimizer states. The `TrainingContext` also contains a `TimingInfo` object, which provides basic information about the timing of each step.

In order to enable custom batching with dependence on the `TrainingContext`, the latter is sent to the batcher within the training loop. This way, batching could, e.g., be made dependent on time or on the step.

The interface of the `fit` method remains largely unchanged, exposing only some of the flexibility of `fit_core` to the user.

To make loss evaluation and model flattening and unflattening more efficient, both `TrainingState` and `EvaluationContext` implement caches, as previously implemented for `CallbackArgs`.

This PR is currently a draft, as the tests have not been updated to work with the new API.
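For orientation, a hedged structural sketch of how these containers could fit together, based only on the description above; every field name here is an assumption for illustration.

```python
import time
from dataclasses import dataclass
from typing import Any


@dataclass
class TimingInfo:
    step_start: float     # wall-clock time at which the current step began
    step_duration: float  # duration of the previous step in seconds


@dataclass
class TrainingState:
    step: int
    flat_params: Any  # flattened model, cached to avoid repeated unflattening
    opt_state: Any


@dataclass
class TrainingContext:
    state: TrainingState  # model/optimizer side
    evaluation: Any       # EvaluationContext: data/loss side, with its own cache
    timing: TimingInfo    # per-step timing information
```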
TODO: