Roadmap to 0.2.0 release #68

@jaosch

Training

  • Split the training loop from klax.fit into its major components:
    • Loss and loss-gradient computation
    • Optimisation: weight update computation
    • Callbacks
    • One loop that glues these components together.
  • Turn the current CallbackArgs class into a TrainingState class (ideally a PyTree, implemented as an eqx.Module). It can then serve as the interaction object that ties the components of the training loop together.
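
The split proposed above could look roughly like the following sketch. It is not klax's implementation: every name in it (TrainingState fields, mse_loss, sgd_update, stop_below, fit) is hypothetical, plain SGD stands in for a real optimiser, and a NamedTuple (which JAX already treats as a PyTree) stands in for an eqx.Module so that the example only depends on JAX.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class TrainingState(NamedTuple):
    """Hypothetical interaction object shared by all components (a PyTree)."""
    params: dict
    step: int
    loss: jax.Array


def mse_loss(params, x, y):
    # Linear model used only for illustration.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


# Component 1: loss and loss-gradient computation.
loss_and_grad = jax.value_and_grad(mse_loss)


# Component 2: weight update computation (plain SGD as a stand-in).
def sgd_update(params, grads, lr=0.1):
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


# Component 3: callbacks read the state and may request an early stop.
def stop_below(threshold):
    def callback(state):
        return bool(state.loss < threshold)
    return callback


# Component 4: one loop that glues the components together.
def fit(state, x, y, steps, callbacks=()):
    for _ in range(steps):
        # The recorded loss is the one evaluated before the update.
        loss, grads = loss_and_grad(state.params, x, y)
        params = sgd_update(state.params, grads)
        state = TrainingState(params, state.step + 1, loss)
        if any(cb(state) for cb in callbacks):
            break
    return state


# Toy data: y = 2x + 1.
x = jnp.linspace(-1.0, 1.0, 32).reshape(-1, 1)
y = 2.0 * x[:, 0] + 1.0
init = TrainingState(
    {"w": jnp.zeros(1), "b": jnp.zeros(())}, 0, jnp.asarray(jnp.inf)
)
final = fit(init, x, y, steps=500, callbacks=[stop_below(1e-4)])
```

Because each component only sees the state (or its params), swapping the loss, the update rule, or the callbacks does not require touching the loop itself, which is the point of the proposed refactoring.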

Models/Architectures

Docs

Explanations

  • Description of klax.fit and how it can be customised. This can be tied to the following examples/tutorials:
    • Example/tutorial on writing custom loss functions.
    • Example/tutorial on using constraints.
    • Example/tutorial on writing custom callbacks.

Labels

roadmap: Roadmap for the future development of klax.
