validation loss computations.
jaosch
left a comment
Very cool changes. The main challenge will be to implement a cache that is easy to reuse for any other callback.
…p, improved docs and removed the data from metric_defs, which now only require the model as input to compute the metrics.
Thanks! :)
With the history callback now storing the metric definitions, it is no longer possible to pickle it, which makes saving difficult. It would be possible to just not save the `metric_defs`, but that means that after loading the user needs to supply the defs again if they want to continue training. 😐
This is a similar problem to model serialization with hyperparameters (see …)
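The "do not save the defs" option mentioned above could look roughly like the following. This is only a sketch: `PicklableHistory`, `record`, and `restore_metric_defs` are hypothetical names used for illustration and are not klax's actual classes or methods.

```python
import pickle


class PicklableHistory:
    """Hypothetical stand-in for a history callback (not klax's class)."""

    def __init__(self, metric_defs):
        # metric_defs maps a metric name to a callable taking the model.
        # Lambdas and closures in here are what usually break pickling.
        self.metric_defs = metric_defs
        self.values = {name: [] for name in metric_defs}

    def record(self, model):
        if self.metric_defs is None:
            raise RuntimeError("metric_defs must be restored after loading")
        for name, fn in self.metric_defs.items():
            self.values[name].append(fn(model))

    def __getstate__(self):
        # Called by pickle: keep the recorded values, drop the callables.
        state = self.__dict__.copy()
        state["metric_defs"] = None
        return state

    def __setstate__(self, state):
        # Called by pickle on load: metric_defs stays None until restored.
        self.__dict__.update(state)

    def restore_metric_defs(self, metric_defs):
        # The user supplies the definitions again to continue training.
        missing = set(self.values) - set(metric_defs)
        if missing:
            raise KeyError(f"missing metric definitions: {sorted(missing)}")
        self.metric_defs = metric_defs


def save(history, path):
    # Works as long as the recorded values themselves are picklable.
    with open(path, "wb") as f:
        pickle.dump(history, f)
```

The trade-off is the one described in the comment: the saved file is self-contained for inspection, but continuing training requires a call like `restore_metric_defs(...)` after loading.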
The recent changes have been mostly reverted, because jax.tree.reduce does not give the desired output with `operator.eq`.
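One plausible reading of the problem, sketched in plain Python: assuming `jax.tree.reduce` folds the function over the flattened leaves the way `functools.reduce` does, `operator.eq` ends up comparing the previous boolean result with the next leaf rather than checking that the leaves match. The helper names `fold_eq` and `all_equal` are made up for this illustration, and the actual use case in the PR may have differed.

```python
import functools
import operator


def fold_eq(leaves):
    # Mimics a reduce with operator.eq over a flat list of leaves:
    # ((leaf0 == leaf1) == leaf2) == ...
    return functools.reduce(operator.eq, leaves)


def all_equal(leaves_a, leaves_b):
    # Leaf-by-leaf comparison of two flattened trees, then one final all().
    return len(leaves_a) == len(leaves_b) and all(
        map(operator.eq, leaves_a, leaves_b)
    )


# Identical leaves, yet the fold reports False: (2 == 2) is True, True == 2 is False.
same_leaves = fold_eq([2, 2, 2])

# Different leaves, yet the fold reports True: (0 == 1) is False, False == 0 is True.
different_leaves = fold_eq([0, 1, 0])
```

Comparing pairwise first and reducing with `all` afterwards avoids the chained-comparison effect.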
Datahandler tests now pass.
This enables compatibility with any `optax.GradientTransformationExtraArgs` without any additional boilerplate for the user.
The `fit` function now calls the `TrainState.create` function for lack of an `__init__`. The `run_training_loop` function now uses a partitioned version of the loss to enable the use of arbitrary optax optimizers again.
`TrainStatic` made compatible with `optax.GradientTransformation` by casting it to `optax.GradientTransformationExtraArgs`. This required the implementation of a constructor, because field converters are not available in Python dataclasses (see PEP 712).
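A minimal sketch of the "converter in the constructor" pattern, written with plain dataclasses so it runs without optax. `Transformation`, `TransformationExtraArgs`, and `with_extra_args` are hypothetical stand-ins for the optax types and for a wrapper such as `optax.with_extra_args_support`. The `TrainStatic` shown here only borrows the name and is not klax's implementation.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Transformation:
    """Stand-in for a plain transformation: update(grads, state)."""
    update: Callable[[Any, Any], Any]


@dataclass(frozen=True)
class TransformationExtraArgs:
    """Stand-in for the extended form: update(grads, state, **extra_args)."""
    update: Callable[..., Any]


def with_extra_args(tx):
    # Already in the extended form: nothing to do.
    if isinstance(tx, TransformationExtraArgs):
        return tx

    def update(grads, state, **extra_args):
        del extra_args  # the plain transformation cannot use them
        return tx.update(grads, state)

    return TransformationExtraArgs(update=update)


# init=False keeps the dataclass machinery from overwriting our __init__.
@dataclass(frozen=True, init=False)
class TrainStatic:
    optimizer: TransformationExtraArgs

    def __init__(self, optimizer):
        # Dataclass fields have no converter hook, so convert here.
        # object.__setattr__ is needed because the class is frozen.
        object.__setattr__(self, "optimizer", with_extra_args(optimizer))
```

With this shape the training loop can always call `static.optimizer.update(grads, state, **extra)` regardless of which kind of optimizer the user passed in.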
All tests in test_training.py pass now.
…r the generalized training.
…ainer. Implemented and integrated a MetricLogger. Streamlined training functionality.
…, instead of spamming.
…more comprehensive docstring.
Key improvements:
- A `training_loop` function that only takes state, static, and callbacks as arguments. The loop updates the state.
- `CallbackArgs` dropped in favor of a `TrainingView`, which provides read-only access to static and read-write access to the state. This greatly simplifies the callbacks.
- `HistoryCallback` dropped in favor of a more general logging framework using a `MetricLogger` callback and a dict-like `History` object. This significantly improves the separation of concerns and enables the user to easily track custom metrics.
- `klax.fit` adapted to recreate the old behaviour with the new components.
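How these pieces might fit together can be sketched in plain Python. The class and function names come from the list above, but every signature, the dict-based state and static, and the `every` parameter are assumptions made for illustration. The real klax components operate on JAX pytrees and will differ.

```python
from collections import UserDict


class History(UserDict):
    """Dict-like store: metric name -> list of (step, value) pairs."""

    def log(self, name, step, value):
        self.data.setdefault(name, []).append((step, value))


class TrainingView:
    """Read-write access to the state, read-only access to static."""

    def __init__(self, state, static):
        self.state = state      # plain attribute: callbacks may replace it
        self._static = static

    @property
    def static(self):
        return self._static     # no setter, so assignment raises AttributeError


class MetricLogger:
    """Callback that evaluates user-supplied metrics every `every` steps."""

    def __init__(self, metrics, every=1):
        self.metrics = metrics  # name -> callable taking the view
        self.every = every
        self.history = History()

    def __call__(self, view):
        step = view.state["step"]
        if step % self.every == 0:
            for name, fn in self.metrics.items():
                self.history.log(name, step, fn(view))


def training_loop(state, static, callbacks):
    # Assumed layout: static holds the step count and the update function.
    view = TrainingView(state, static)
    for _ in range(static["steps"]):
        view.state = static["update"](view.state)
        for callback in callbacks:
            callback(view)
    return view.state


# Toy usage: halve a single weight for four steps, log it every second step.
static = {
    "steps": 4,
    "update": lambda s: {"step": s["step"] + 1, "w": s["w"] * 0.5},
}
logger = MetricLogger({"w": lambda view: view.state["w"]}, every=2)
final_state = training_loop({"step": 0, "w": 4.0}, static, [logger])
```

Because the logger only needs the view, a custom metric is just another entry in the `metrics` dict, which is the separation of concerns the list above describes.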