
Revised training core #81

Merged: Drenderer merged 49 commits into develop from feature/generalized_training on Jan 28, 2026.

Conversation

@Drenderer (Owner) commented Dec 1, 2025

Key improvements:

  • Split the training ingredients into state (model and optimizer state) and static (optimizer, batcher, loss, etc.)
  • Defined a new core training_loop function that takes only state, static, and callbacks as arguments; the loop updates the state (see the sketch after this list).
  • Removed CallbackArgs in favor of a TrainingView, which provides read-only access to static and read-write access to the state. This greatly simplifies the callbacks.
  • Defined an ABC for losses, which allows for custom gradient computations and implements model unwrapping by default (avoiding a redefinition of the loss in the fit function).
  • Refactored the datahandler.
  • Removed the HistoryCallback in favor of a more general logging framework built from a MetricLogger callback and a dict-like History object. This significantly improves the separation of concerns and lets the user easily track custom metrics.
  • Refactored klax.fit to recreate the old behaviour with the new components.
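
A minimal self-contained sketch of the state/static split and the TrainingView pattern described above; every name and signature here is an illustrative assumption, not the actual klax API:

```python
from dataclasses import dataclass
from typing import Any, Callable

import jax
import jax.numpy as jnp
import optax


@dataclass
class TrainState:
    """Per-step, read-write ingredients."""
    params: Any
    opt_state: Any
    step: int = 0


@dataclass(frozen=True)
class TrainStatic:
    """Ingredients fixed for the whole run."""
    optimizer: optax.GradientTransformation
    loss: Callable
    batches: tuple


@dataclass(frozen=True)
class TrainingView:
    """Read-only access to static, read-write access to state."""
    state: TrainState
    static: TrainStatic


def training_loop(state, static, callbacks=()):
    for batch in static.batches:
        grads = jax.grad(static.loss)(state.params, batch)
        updates, state.opt_state = static.optimizer.update(grads, state.opt_state)
        state.params = optax.apply_updates(state.params, updates)
        state.step += 1
        for callback in callbacks:
            callback(TrainingView(state, static))
    return state


# Toy usage: fit y = 2x with a squared loss.
def loss(params, batch):
    x, y = batch
    return jnp.mean((params * x - y) ** 2)


data = (jnp.array([1.0, 2.0]), jnp.array([2.0, 4.0]))
optimizer = optax.sgd(0.1)
params = jnp.array(0.0)
state = TrainState(params, optimizer.init(params))
static = TrainStatic(optimizer, loss, (data,) * 100)
state = training_loop(state, static)
```

Under this reading, klax.fit recreates the old single-call behaviour by assembling these pieces and delegating to the loop.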

@Drenderer Drenderer marked this pull request as draft December 1, 2025 05:52
@jaosch (Collaborator) left a comment

Very cool changes. The main challenge will be to implement a cache that is easy to reuse for any other callback.

…p, improved docs and removed the data from metric_defs, which now only require the model as input to compute the metrics.
@Drenderer (Owner, Author) commented Dec 11, 2025

> Very cool changes. The main challenge will be to implement a cache that is easy to reuse for any other callback.

Thanks! :)
But I disagree about the caching: the user should just provide a cached function as the metric_def.
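
For illustration, a minimal sketch of such a user-side cached metric_def, assuming metric_defs are plain callables that take only the model (the `cached` helper and the `param_norm` metric are hypothetical):

```python
import functools

import equinox as eqx
import jax
import jax.numpy as jnp


def cached(metric_fn):
    """Memoize a metric on the identity of the model pytree, so that
    several callbacks within one step reuse the computed value."""
    last = {"key": None, "value": None}

    @functools.wraps(metric_fn)
    def wrapper(model):
        # id-based keying assumes each step produces a fresh model object.
        if id(model) != last["key"]:
            last["key"] = id(model)
            last["value"] = metric_fn(model)
        return last["value"]

    return wrapper


@cached
def param_norm(model):  # hypothetical metric_def: model in, scalar out
    params = eqx.filter(model, eqx.is_inexact_array)
    return sum(jnp.sum(p**2) for p in jax.tree.leaves(params))
```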

@Drenderer (Owner, Author) commented

With the history callback now storing the metric definitions, it is no longer possible to pickle it. This makes saving difficult. It would be possible to simply not save the metric_defs, but then, after loading, the user would need to supply the defs again to continue training. 😐
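
A hedged sketch of the "do not save the metric_defs" option (History is dict-like per the description above; the helper names and the rebuild step are hypothetical):

```python
import pickle


def save_history(path, history):
    # Persist only the logged values; the metric definitions are plain
    # functions and are dropped because they cannot be pickled.
    with open(path, "wb") as f:
        pickle.dump(dict(history), f)


def load_history(path, history):
    # `history` is rebuilt by the user with the metric_defs supplied
    # again; the saved values are merged back in.
    with open(path, "rb") as f:
        history.update(pickle.load(f))  # assumes dict-like update
    return history
```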

@jaosch (Collaborator) commented Dec 15, 2025

> With the history callback now storing the metric definitions, it is no longer possible to pickle it. This makes saving difficult. It would be possible to simply not save the metric_defs, but then, after loading, the user would need to supply the defs again to continue training. 😐

This is a similar problem to model serialization with hyperparameters (see https://docs.kidger.site/equinox/examples/serialisation/), but I am not yet sure what the solution could be.
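
For reference, the pattern from the linked Equinox docs serialises the array leaves separately and rebuilds the skeleton from the hyperparameters on load; roughly (names hypothetical):

```python
import json

import equinox as eqx


def save(path, hyperparams, model):
    with open(path, "wb") as f:
        f.write((json.dumps(hyperparams) + "\n").encode())
        eqx.tree_serialise_leaves(f, model)


def load(path, make_model):
    with open(path, "rb") as f:
        hyperparams = json.loads(f.readline().decode())
        skeleton = make_model(**hyperparams)  # rebuild the pytree structure
        return eqx.tree_deserialise_leaves(f, skeleton)
```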

jaosch and others added 11 commits December 15, 2025 16:53
The recent changes have been mostly reverted because jax.tree.reduce
does not give the desired output with `operator.eq`.
Datahandler tests now pass.
This enables compatibility with any
`optax.GradientTransformationExtraArgs` without any additional
boilerplate for the user.
The `fit` function now calls the `TrainState.create` function for lack
of an `__init__`.

The `run_training_loop` function now uses a partitioned version of the
loss to enable the use of arbitrary optax optimizers again.
`TrainStatic` was made compatible with `optax.GradientTransformation` by
casting to `optax.GradientTransformationExtraArgs`. This required
implementing a constructor, because field converters are not yet
available in Python (see PEP 712).
All tests in test_training.py pass now.
…ainer. Implemented and integrated a MetricLogger. Streamlined training functionality.
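
A sketch of the cast mentioned in the commits above, assuming `TrainStatic` is an equinox `Module` (the field name is illustrative). `optax.with_extra_args_support` performs the upgrade that a PEP 712 field converter would otherwise express declaratively:

```python
import equinox as eqx
import optax


class TrainStatic(eqx.Module):
    optimizer: optax.GradientTransformationExtraArgs

    def __init__(self, optimizer: optax.GradientTransformation):
        # Field converters (PEP 712) are not available yet, so the
        # constructor casts plain transformations by hand.
        self.optimizer = optax.with_extra_args_support(optimizer)


static = TrainStatic(optax.adam(1e-3))  # plain transformation, upgraded
assert isinstance(static.optimizer, optax.GradientTransformationExtraArgs)
```
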
@Drenderer Drenderer marked this pull request as ready for review January 27, 2026 13:38
@Drenderer Drenderer merged commit e103a67 into develop Jan 28, 2026
1 check passed
@Drenderer Drenderer deleted the feature/generalized_training branch January 28, 2026 20:46