
Rewrite of training, callbacks, losses, optimizers and batchers for more flexibility.#72

Draft
jaosch wants to merge 9 commits into develop from feature/modularize-training

Conversation

@jaosch
Collaborator

@jaosch jaosch commented Sep 13, 2025

The main change is a new fit_core function, which implements a training loop that depends only on a batcher, an updater, and a TrainingContext.
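
A minimal sketch of how such a loop could look (all names here are illustrative, not the exact API of this PR):

```python
# Illustrative sketch only; function and attribute names are assumptions.
def fit_core(batcher, updater, context, max_steps):
    """Training loop that only asks for a batch and applies an update."""
    for _ in range(max_steps):
        batch = batcher(context)           # batching may depend on the current context
        context = updater(context, batch)  # updater advances model/optimizer state
    return context
```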

The newly added TrainingContext includes the functionality of the former CallbackArgs, but is further subdivided internally, to separate data and losses from model and optimizer states. The TrainingContext also contains a TimingInfo object, which provides basic information about the time of each step.
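
For illustration, the internal subdivision could look roughly like this (field names are assumptions, not the actual attributes):

```python
from dataclasses import dataclass, field
from typing import Any
import time

@dataclass
class TimingInfo:
    """Basic timing information about each step (illustrative fields)."""
    step_start: float = field(default_factory=time.perf_counter)
    last_step_duration: float = 0.0

@dataclass
class TrainingState:
    """Model and optimizer state, kept separate from data and losses."""
    model: Any
    opt_state: Any
    step: int = 0

@dataclass
class TrainingContext:
    """Bundles the state, the (data/loss) evaluation context, and timing."""
    state: TrainingState
    evaluation: Any  # e.g. an EvaluationContext holding data and cached losses
    timing: TimingInfo = field(default_factory=TimingInfo)
```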

To enable custom batching that depends on the TrainingContext, the latter is passed to the batcher within the training loop. This way, batching can, e.g., be made dependent on time or on the step.

The interface of the fit method remains largely unchanged, exposing only some of the flexibility of fit_core to the user.

To make loss evaluation and model flattening and unflattening more efficient, both TrainingState and EvaluationContext implement caches, as previously implemented for CallbackArgs.

This PR is currently a draft, as the tests have not been updated to work with the new API.

TODO:

  • Update tests
  • Update docstrings
  • Update documentation
  • Update examples

The `fit` function is modularized.

The main loop is performed within the `_fit_core` function, which knows
nothing about the data, optimizer, model, or loss. It only performs the
update.

The original API of `fit` is provided as a wrapper that exposes some
parameters of the training components to the user, e.g., the batch axis.
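
A rough sketch of that wrapper idea; the helper functions below are hypothetical placeholders, only the delegation pattern is the point:

```python
# Hypothetical wrapper: `make_default_batcher`, `make_default_updater`, and
# `make_initial_context` are placeholders, not functions added by this PR.
def fit(model, data, *, batch_axis=0, batch_size=32, steps=1000):
    batcher = make_default_batcher(data, batch_axis=batch_axis, batch_size=batch_size)
    updater = make_default_updater(model)
    context = make_initial_context(model, data)
    return _fit_core(batcher, updater, context, max_steps=steps)
```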

Different `Protocol`s are implemented for the loss, the updater, and the
batcher. None of them introduces a new layer of abstraction, i.e., no new
classes; they merely serve as templates for the different functions
required for training.
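
For example, such protocol templates could be as small as this (signatures are illustrative):

```python
from typing import Any, Protocol

class Loss(Protocol):
    """Template for a loss function: model and batch in, scalar out."""
    def __call__(self, model: Any, batch: Any) -> float: ...

class Updater(Protocol):
    """Template for an update function: advances the training context by one step."""
    def __call__(self, context: Any, batch: Any) -> Any: ...

class Batcher(Protocol):
    """Template for a batch generator: may inspect the training context/state."""
    def __call__(self, context: Any) -> Any: ...
```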

Note: Callbacks have not been implemented in this approach yet.
An example was added at the bottom of the module to showcase the
flexibility of the `fit_core` method for building custom training loops.
The protocol for BatchGenerators has been changed to allow a
`TrainingState` to be sent to the generator. This enables stateful
batching, where, e.g., the step of the model can be used as a seed for
random number generation.
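
A possible sketch of such a stateful batch generator (NumPy is used here only for illustration; the attribute `state.step` is an assumption):

```python
import numpy as np

def make_batcher(x, y, batch_size):
    """Batch generator whose randomness is seeded by the current training step."""
    def batcher(state):
        rng = np.random.default_rng(seed=state.step)  # step-dependent, hence reproducible
        idx = rng.choice(len(x), size=batch_size, replace=False)
        return x[idx], y[idx]
    return batcher
```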

Additionally, the task of flattening and unflattening the model was moved
to the `TrainingState`, making the `fit_core` function even more
minimal.
The `EvaluationContext` takes care of lazy loss evaluation. It
receives the training state as input and stores the loss and validation
loss in a cache. The cache is updated whenever the training step
changes.
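
The caching idea could be sketched like this (illustrative, not the exact implementation):

```python
class EvaluationContext:
    """Lazy loss evaluation: the loss is recomputed only when the step changes."""

    def __init__(self, loss_fn, data):
        self._loss_fn = loss_fn
        self._data = data
        self._cached_step = None
        self._cached_loss = None

    def loss(self, state):
        if state.step != self._cached_step:  # cache miss: a new training step
            self._cached_loss = self._loss_fn(state.model, self._data)
            self._cached_step = state.step
        return self._cached_loss
```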

The `TrainingContext` is a simple wrapper around a `TrainingState` and
an `EvaluationContext`. It serves as the main input to every `Callback`.

This new implementation nicely separates the "dumb" training state from
the data and the evaluation.
The dependency on the training step, which was introduced as a test, has
been removed.
Comment on lines +137 to +138
if init_opt_state is None
else init_opt_state,
Owner


Typo?

