From 0b2391003fea197eaceef31e37fdd7162d85b389 Mon Sep 17 00:00:00 2001 From: jaosch Date: Fri, 12 Sep 2025 23:38:00 +0200 Subject: [PATCH 1/9] Modularized fit The `fit` function is modualized. The main loop is performed within the `_fit_core` function, which knows nothing about the data, optimizer, model, and loss. It only performs the update. The original API of `fit` is provided as a wrapper, which exposes some parameters of the training components to the user, e.g., batch axis. Different `Protocols` are implemented for the loss, the updater and the batcher, but none of them create new layers of abstraction, i.e. not new classes, but mere serve as templates for the different functions required for training. Note: Callbacks have not been implemented in this approach yet. --- klax/_new_training.py | 208 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 208 insertions(+) create mode 100644 klax/_new_training.py diff --git a/klax/_new_training.py b/klax/_new_training.py new file mode 100644 index 0000000..124c476 --- /dev/null +++ b/klax/_new_training.py @@ -0,0 +1,208 @@ +# Copyright 2025 The Klax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implements a training loop.""" + +import typing +from abc import ABC, abstractmethod +from collections.abc import Callable, Generator +from dataclasses import dataclass +from typing import Any, Protocol + +import equinox as eqx +import jax +import jax.numpy as jnp +import optax +from jaxtyping import PRNGKeyArray, PyTree, Scalar + +from ._datahandler import BatchGenerator, batch_data +from ._wrappers import apply + + +@dataclass +class TrainingState: + model: PyTree + opt_state: PyTree + step: int = 0 + + +@typing.runtime_checkable +class ValueFn(Protocol): + @abstractmethod + def __call__(self, model: PyTree, data: PyTree) -> Scalar: + pass + + +@typing.runtime_checkable +class ValueAndGradFn(Protocol): + @abstractmethod + def __call__(self, model: PyTree, data: PyTree) -> tuple[Scalar, PyTree]: + pass + + +class LossFactory(Protocol): + @abstractmethod + def __call__(self, batch_axis) -> tuple[ValueFn, ValueAndGradFn]: + pass + + +def mse(batch_axis) -> tuple[ValueFn, ValueAndGradFn]: + def value_fn(model: PyTree, data: PyTree) -> Scalar: + x, y = data + if isinstance(batch_axis, tuple): + in_axes = batch_axis[0] + else: + in_axes = batch_axis + y_pred = eqx.filter_vmap(model, in_axes=(in_axes,))(x) + return jnp.mean(jnp.square(y_pred - y)) + + def value_and_grad_fn( + model: PyTree, data: PyTree + ) -> tuple[Scalar, PyTree]: + return eqx.filter_value_and_grad(value_fn)(model, data) + + return value_fn, value_and_grad_fn + + +class Updater(Protocol): + @abstractmethod + def __call__( + self, model: PyTree, batch: PyTree, opt_state: PyTree + ) -> tuple[PyTree, PyTree]: + pass + + +class UpdaterFactory(Protocol): + @abstractmethod + def __call__( + self, + opt_update: Callable, + value_fn: ValueFn, + value_and_grad_fn: ValueAndGradFn, + ) -> Updater: + pass + + +def optax_updater(opt_update, value_fn, value_and_grad_fn) -> Updater: + def wrapper(model, batch, opt_state): + value, grad = value_and_grad_fn(model, batch) + updates, opt_state = opt_update( + grad, + opt_state, + eqx.filter(model, eqx.is_inexact_array), + value=value, + grad=grad, + value_fn=jax.tree_util.Partial(value_fn, model=model, batch=batch), + ) + model = eqx.apply_updates(model, updates) + return model, opt_state + + return wrapper + + +def _fit_core[T: eqx.Module]( + updater: Updater, + batcher: Generator[PyTree[Any], None, None], + state: TrainingState, + steps: int, +): + @eqx.filter_jit + def make_step(batch, flat_model, flat_opt_state): + # Use the unflatten trick to speed up training, + # see https://docs.kidger.site/equinox/tricks/ + model = jax.tree_util.tree_unflatten(treedef_model, flat_model) + opt_state = jax.tree_util.tree_unflatten( + treedef_opt_state, flat_opt_state + ) + + # Compute and apply the parameter updates + # params, static = eqx.partition(model, eqx.is_inexact_array) + # params = updater(params, static, batch, opt_state) + model, opt_state = updater(model, batch, opt_state) + + # Apply the Constraint in the model to ensure apply-constrains are met + # after the update. + model = apply(model) + + flat_model = jax.tree_util.tree_leaves(model) + flat_opt_state = jax.tree_util.tree_leaves(opt_state) + + return flat_model, flat_opt_state + + # Apply the Constraint in the model to ensure apply-constrains are met + # initially + state.model = apply(state.model) + + # Use the unflatten trick to speed up training, + # see https://docs.kidger.site/equinox/tricks/ + flat_model, treedef_model = jax.tree.flatten(state.model) + flat_opt_state, treedef_opt_state = jax.tree.flatten(state.opt_state) + + for state.step in range(state.step, state.step + steps + 1): + batch = next(batcher) + flat_model, flat_opt_state = make_step( + batch, flat_model, flat_opt_state + ) + + state.model = jax.tree_util.tree_unflatten(treedef_model, flat_model) + state.opt_state = jax.tree_util.tree_unflatten( + treedef_opt_state, flat_opt_state + ) + + return state + + +def fit( + model, + data, + *, + batch_size: int = 32, + batch_axis: PyTree[int | None] = 0, + steps: int = 1_000, + loss: LossFactory = mse, + optimizer: optax.GradientTransformation, + init_opt_state: PyTree[Any] = None, + batcher: BatchGenerator = batch_data, + updater: UpdaterFactory = optax_updater, + key: PRNGKeyArray, +): + state = TrainingState( + model=model, + opt_state=optimizer.init(eqx.filter(model, eqx.is_inexact_array)) + if init_opt_state is None + else init_opt_state + if init_opt_state is None + else init_opt_state, + ) + + state = _fit_core( + updater(optimizer.update, *loss(batch_axis)), + batcher( + data=data, batch_axis=batch_axis, batch_size=batch_size, key=key + ), + state, + steps, + ) + return state.model + + +if __name__ == "__main__": + x = jnp.linspace(0.0, 1.0, 2).reshape(-1, 1) + y = 2.0 * x + 1.0 + model = eqx.nn.Linear(1, 1, key=eqx.internal.GetKey()()) + model = fit( + model, (x, y), optimizer=optax.adam(1.0), key=eqx.internal.GetKey()() + ) + y_pred = jax.vmap(model)(x) + assert jnp.allclose(y_pred, y) From 75a3f96d854f719167ff0496d7c4419820288a0f Mon Sep 17 00:00:00 2001 From: jaosch Date: Sat, 13 Sep 2025 00:05:38 +0200 Subject: [PATCH 2/9] Concretization of type annotations + example for `fit_core` An example was added at the bottom of the module to show-case the flexibility of the `fit_core` method for building custom trainings. --- klax/_new_training.py | 55 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/klax/_new_training.py b/klax/_new_training.py index 124c476..ab364c3 100644 --- a/klax/_new_training.py +++ b/klax/_new_training.py @@ -87,14 +87,36 @@ class UpdaterFactory(Protocol): @abstractmethod def __call__( self, - opt_update: Callable, + opt_update: optax.TransformUpdateFn | optax.TransformUpdateExtraArgsFn, value_fn: ValueFn, value_and_grad_fn: ValueAndGradFn, ) -> Updater: pass -def optax_updater(opt_update, value_fn, value_and_grad_fn) -> Updater: +def optax_transform_update_fn_updater( + opt_update: optax.TransformUpdateFn, + value_fn: ValueFn, + value_and_grad_fn: ValueAndGradFn, +) -> Updater: + def wrapper(model, batch, opt_state): + value, grad = value_and_grad_fn(model, batch) + updates, opt_state = opt_update( + grad, + opt_state, + eqx.filter(model, eqx.is_inexact_array), + ) + model = eqx.apply_updates(model, updates) + return model, opt_state + + return wrapper + + +def optax_transform_update_fn_extra_args_updater( + opt_update: optax.TransformUpdateExtraArgsFn, + value_fn: ValueFn, + value_and_grad_fn: ValueAndGradFn, +) -> Updater: def wrapper(model, batch, opt_state): value, grad = value_and_grad_fn(model, batch) updates, opt_state = opt_update( @@ -111,7 +133,7 @@ def wrapper(model, batch, opt_state): return wrapper -def _fit_core[T: eqx.Module]( +def fit_core[T: eqx.Module]( updater: Updater, batcher: Generator[PyTree[Any], None, None], state: TrainingState, @@ -174,7 +196,7 @@ def fit( optimizer: optax.GradientTransformation, init_opt_state: PyTree[Any] = None, batcher: BatchGenerator = batch_data, - updater: UpdaterFactory = optax_updater, + updater: UpdaterFactory = optax_transform_update_fn_updater, key: PRNGKeyArray, ): state = TrainingState( @@ -186,7 +208,7 @@ def fit( else init_opt_state, ) - state = _fit_core( + state = fit_core( updater(optimizer.update, *loss(batch_axis)), batcher( data=data, batch_axis=batch_axis, batch_size=batch_size, key=key @@ -198,6 +220,7 @@ def fit( if __name__ == "__main__": + # Test fit x = jnp.linspace(0.0, 1.0, 2).reshape(-1, 1) y = 2.0 * x + 1.0 model = eqx.nn.Linear(1, 1, key=eqx.internal.GetKey()()) @@ -206,3 +229,25 @@ def fit( ) y_pred = jax.vmap(model)(x) assert jnp.allclose(y_pred, y) + + # Test fit_core + x = jnp.linspace(0.0, 1.0, 2).reshape(-1, 1) + y = 2.0 * x + 1.0 + model = eqx.nn.Linear(1, 1, key=eqx.internal.GetKey()()) + batch_axis = 0 + optimizer = optax.adam(1.0) + state = TrainingState( + model=model, + opt_state=optimizer.init(eqx.filter(model, eqx.is_inexact_array)), + ) + batcher = batch_data( + (x, y), + batch_size=32, + batch_axis=batch_axis, + key=eqx.internal.GetKey()(), + ) + loss = mse(batch_axis) + updater = optax_transform_update_fn_updater(optimizer.update, *loss) + state = fit_core(updater, batcher, state, steps=1000) + y_pred = jax.vmap(state.model)(x) + assert jnp.allclose(y_pred, y) From 08e4f6b286f31f6e4876c72b654bd44c68ee154a Mon Sep 17 00:00:00 2001 From: jaosch Date: Sat, 13 Sep 2025 00:08:31 +0200 Subject: [PATCH 3/9] Tiny fix --- klax/_new_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/klax/_new_training.py b/klax/_new_training.py index ab364c3..314ffce 100644 --- a/klax/_new_training.py +++ b/klax/_new_training.py @@ -100,7 +100,7 @@ def optax_transform_update_fn_updater( value_and_grad_fn: ValueAndGradFn, ) -> Updater: def wrapper(model, batch, opt_state): - value, grad = value_and_grad_fn(model, batch) + _, grad = value_and_grad_fn(model, batch) updates, opt_state = opt_update( grad, opt_state, From 167fc1c524838cda291f676952eca7b4792f35ed Mon Sep 17 00:00:00 2001 From: jaosch Date: Sat, 13 Sep 2025 12:03:33 +0200 Subject: [PATCH 4/9] Enable stateful batching + training state extended with treedefs The protocol for BatchGenerators has been changed to allow for a `TrainingState` to be sent to the generator. This enables statefull batching, where, e.g., the step of the model can be used as a random seed for jandom number generation. Additionally, the task of flattening and unflattening the model was used to the `TrainingState', making the `fit_core` function even more minimal. --- klax/_new_training.py | 161 +++++++++++++++++++++++++++++++++++------- 1 file changed, 135 insertions(+), 26 deletions(-) diff --git a/klax/_new_training.py b/klax/_new_training.py index 314ffce..abf3280 100644 --- a/klax/_new_training.py +++ b/klax/_new_training.py @@ -23,19 +23,77 @@ import equinox as eqx import jax import jax.numpy as jnp +import jax.random as jr +import numpy as np import optax from jaxtyping import PRNGKeyArray, PyTree, Scalar -from ._datahandler import BatchGenerator, batch_data +from ._datahandler import broadcast_and_get_size from ._wrappers import apply @dataclass class TrainingState: - model: PyTree - opt_state: PyTree + flat_model: PyTree + flat_opt_state: PyTree + treedef_model: PyTree + treedef_opt_state: PyTree step: int = 0 + def __init__(self, model: PyTree, opt_state: PyTree = None, step: int = 0): + # Apply the Constraint in the model to ensure apply-constrains are met + # initially + model = apply(model) + + # Use the unflatten trick to speed up training, + # see https://docs.kidger.site/equinox/tricks/ + flat_model, treedef_model = jax.tree.flatten(model) + flat_opt_state, treedef_opt_state = jax.tree.flatten(opt_state) + + self.flat_model = flat_model + self.flat_opt_state = flat_opt_state + self.treedef_model = treedef_model + self.treedef_opt_state = treedef_opt_state + self.step = step + + @property + def model(self) -> PyTree: + return jax.tree_util.tree_unflatten( + self.treedef_model, self.flat_model + ) + + @property + def opt_state(self) -> PyTree: + return jax.tree_util.tree_unflatten( + self.treedef_opt_state, self.flat_opt_state + ) + + def update( + self, flat_model: PyTree, flat_opt_state: PyTree, step: int + ) -> None: + self.flat_model = flat_model + self.flat_opt_state = flat_opt_state + self.step = step + + +class Callback(ABC): + """An abstract callback. + + Inherit from this class to create a custom callback. + """ + + def __call__(self, state: TrainingState) -> bool | None: + """Call after each step during training.""" + pass + + def on_training_end(self, state: TrainingState) -> None: + """Call when training ends.""" + pass + + def on_training_start(self, state: TrainingState) -> None: + """Call when training starts.""" + pass + @typing.runtime_checkable class ValueFn(Protocol): @@ -133,9 +191,69 @@ def wrapper(model, batch, opt_state): return wrapper +@typing.runtime_checkable +class Batcher(Protocol): + @abstractmethod + def __call__( + self, + data: PyTree[Any], + batch_size: int, + batch_axis: int, + *, + key: PRNGKeyArray, + ) -> Generator[PyTree[Any], TrainingState, None]: + pass + + +def stateful_batch_data( + data: PyTree[Any], + batch_size: int, + batch_axis: int, + convert_to_numpy: bool = True, + *, + key: PRNGKeyArray, # Only cor compliance with the `Batcher` protocol +) -> Generator[PyTree[Any], TrainingState, None]: + """Create a stateful batch generator that uses the step as seed.""" + batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis) + + # Convert to Numpy arrays. Numpy's slicing is much faster than JAX's, so + # for fast model training steps this actually makes a huge difference! + # However, be aware that this is likely only true if JAX runs on CPU. + if convert_to_numpy: + data = jax.tree.map( + lambda x, a: x if a is None else np.array(x), + data, + batch_axis, + is_leaf=lambda x: x is None, + ) + + # Reduce batch size if the dataset has less examples than batch size + batch_size = min(batch_size, dataset_size) + + indices = jnp.arange(dataset_size) + while True: + # Store the training state as received by the `.send(state)` within + # the training loop. + state: TrainingState = yield + key = jax.random.PRNGKey(state.step) # Create key from step + perm = jr.permutation(key, indices) + (key,) = jr.split(key, 1) # Update key + start, end = 0, batch_size + while end <= dataset_size: + batch_perm = perm[start:end] + yield jax.tree.map( + lambda a, x: x if a is None else x[batch_perm], + batch_axis, + data, + is_leaf=lambda x: x is None, + ) + start = end + end = start + batch_size + + def fit_core[T: eqx.Module]( updater: Updater, - batcher: Generator[PyTree[Any], None, None], + batcher: Generator[PyTree[Any], TrainingState, None], state: TrainingState, steps: int, ): @@ -143,9 +261,9 @@ def fit_core[T: eqx.Module]( def make_step(batch, flat_model, flat_opt_state): # Use the unflatten trick to speed up training, # see https://docs.kidger.site/equinox/tricks/ - model = jax.tree_util.tree_unflatten(treedef_model, flat_model) + model = jax.tree_util.tree_unflatten(state.treedef_model, flat_model) opt_state = jax.tree_util.tree_unflatten( - treedef_opt_state, flat_opt_state + state.treedef_opt_state, flat_opt_state ) # Compute and apply the parameter updates @@ -162,26 +280,13 @@ def make_step(batch, flat_model, flat_opt_state): return flat_model, flat_opt_state - # Apply the Constraint in the model to ensure apply-constrains are met - # initially - state.model = apply(state.model) - - # Use the unflatten trick to speed up training, - # see https://docs.kidger.site/equinox/tricks/ - flat_model, treedef_model = jax.tree.flatten(state.model) - flat_opt_state, treedef_opt_state = jax.tree.flatten(state.opt_state) - for state.step in range(state.step, state.step + steps + 1): - batch = next(batcher) - flat_model, flat_opt_state = make_step( - batch, flat_model, flat_opt_state + next(batcher) + batch = batcher.send(state) # Send the state back to the batcher + state.flat_model, state.flat_opt_state = make_step( + batch, state.flat_model, state.flat_opt_state ) - state.model = jax.tree_util.tree_unflatten(treedef_model, flat_model) - state.opt_state = jax.tree_util.tree_unflatten( - treedef_opt_state, flat_opt_state - ) - return state @@ -195,7 +300,7 @@ def fit( loss: LossFactory = mse, optimizer: optax.GradientTransformation, init_opt_state: PyTree[Any] = None, - batcher: BatchGenerator = batch_data, + batcher: Batcher = stateful_batch_data, updater: UpdaterFactory = optax_transform_update_fn_updater, key: PRNGKeyArray, ): @@ -240,14 +345,18 @@ def fit( model=model, opt_state=optimizer.init(eqx.filter(model, eqx.is_inexact_array)), ) - batcher = batch_data( + batcher = stateful_batch_data( (x, y), batch_size=32, batch_axis=batch_axis, - key=eqx.internal.GetKey()(), + key=eqx.internal.GetKey()(), # Unused ) loss = mse(batch_axis) updater = optax_transform_update_fn_updater(optimizer.update, *loss) state = fit_core(updater, batcher, state, steps=1000) y_pred = jax.vmap(state.model)(x) assert jnp.allclose(y_pred, y) + + import pprint + + pprint.pp(state) From a7fc3a2daeb0e055552c30464f5bfd1f35ce023d Mon Sep 17 00:00:00 2001 From: jaosch Date: Sat, 13 Sep 2025 13:36:03 +0200 Subject: [PATCH 5/9] Added EvaluationContext and TrainingContext The `EvaluationContext` takes care of the lazy loss evaluation. It receives the training state as input and stores the loss and validation loss in a cache. The cache is updated whenever the training step changed. The `TrainingContext` is a simple wrapper around a `TrainingState` and an `EvaluatorContext`. It serves as the main input to every `Callback`. This new implementation nicely separates the `dump` training state from the data and the evaluation. --- klax/_new_training.py | 257 +++++++++++++++++++++++++++++++++--------- 1 file changed, 201 insertions(+), 56 deletions(-) diff --git a/klax/_new_training.py b/klax/_new_training.py index abf3280..e701676 100644 --- a/klax/_new_training.py +++ b/klax/_new_training.py @@ -16,9 +16,9 @@ import typing from abc import ABC, abstractmethod -from collections.abc import Callable, Generator -from dataclasses import dataclass -from typing import Any, Protocol +from collections.abc import Callable, Generator, Iterable +from dataclasses import dataclass, field +from typing import Any, Protocol, Self import equinox as eqx import jax @@ -32,15 +32,17 @@ from ._wrappers import apply -@dataclass class TrainingState: - flat_model: PyTree - flat_opt_state: PyTree - treedef_model: PyTree - treedef_opt_state: PyTree - step: int = 0 - - def __init__(self, model: PyTree, opt_state: PyTree = None, step: int = 0): + _flat_model: PyTree + _flat_opt_state: PyTree + _treedef_model: PyTree + _treedef_opt_state: PyTree + _step: int + _cache: dict[str, Any] = {} + + def __init__( + self, model: PyTree, opt_state: PyTree = None, initial_step: int = 0 + ): # Apply the Constraint in the model to ensure apply-constrains are met # initially model = apply(model) @@ -50,49 +52,78 @@ def __init__(self, model: PyTree, opt_state: PyTree = None, step: int = 0): flat_model, treedef_model = jax.tree.flatten(model) flat_opt_state, treedef_opt_state = jax.tree.flatten(opt_state) - self.flat_model = flat_model - self.flat_opt_state = flat_opt_state - self.treedef_model = treedef_model - self.treedef_opt_state = treedef_opt_state - self.step = step + self._flat_model = flat_model + self._flat_opt_state = flat_opt_state + self._treedef_model = treedef_model + self._treedef_opt_state = treedef_opt_state + self._step = initial_step + + @staticmethod + def _lazy_chached_property(fun: Callable) -> property: + """Turn a public method into a lazily evaluated property. + + The return value of ``fun`` is stored in the ``_cache`` dictionary of + the current object using the function name as key. If the name is + already in ``_cache`` then the cached value is simply returned, + without evaluating ``fun``. + + Args: + fun: Method to wrap. + + Returns: + Wrapped method as a property. + + """ + attr_name = fun.__name__ + + def wrapper(self: Self): + if attr_name not in self._cache: + self._cache.setdefault(attr_name, fun(self)) + return self._cache.get(attr_name) + + wrapper.__doc__ = fun.__doc__ + + return property(wrapper) @property + def flat_model(self) -> PyTree: + return self._flat_model + + @property + def flat_opt_state(self) -> PyTree: + return self._flat_opt_state + + @property + def step(self) -> int: + return self._step + + @_lazy_chached_property def model(self) -> PyTree: return jax.tree_util.tree_unflatten( - self.treedef_model, self.flat_model + self._treedef_model, self._flat_model ) - @property + @_lazy_chached_property def opt_state(self) -> PyTree: return jax.tree_util.tree_unflatten( - self.treedef_opt_state, self.flat_opt_state + self._treedef_opt_state, self._flat_opt_state ) - def update( - self, flat_model: PyTree, flat_opt_state: PyTree, step: int - ) -> None: - self.flat_model = flat_model - self.flat_opt_state = flat_opt_state - self.step = step - - -class Callback(ABC): - """An abstract callback. - - Inherit from this class to create a custom callback. - """ + @property + def treedef_model(self) -> PyTree: + return self._treedef_model - def __call__(self, state: TrainingState) -> bool | None: - """Call after each step during training.""" - pass + @property + def treedef_opt_state(self) -> PyTree: + return self._treedef_opt_state - def on_training_end(self, state: TrainingState) -> None: - """Call when training ends.""" - pass + def update(self, flat_model: PyTree, flat_opt_state: PyTree): + self._flat_model = flat_model + self._flat_opt_state = flat_opt_state + self._step += self._step - def on_training_start(self, state: TrainingState) -> None: - """Call when training starts.""" - pass + # Clear cache + self._cache.clear() @typing.runtime_checkable @@ -251,19 +282,102 @@ def stateful_batch_data( end = start + batch_size +@dataclass +class EvaluationContext: + value_fn: ValueFn + data: PyTree[Any] + val_data: PyTree[Any] | None = None + _cached_step: int | None = None + _cache: dict[str, Any] = field(default_factory=dict) + + def _ensure_step(self, state: TrainingState): + if self._cached_step != state.step: + self._cache.clear() + self._cached_step = state.step + + @eqx.filter_jit + def _loss_impl(self, model: PyTree, batch: PyTree[Any]): + return self.value_fn(model, batch) + + def loss(self, state: TrainingState) -> Scalar: + self._ensure_step(state) + if "loss" not in self._cache: + self._cache["loss"] = self._loss_impl(state.model, self.data) + return self._cache["loss"] + + def val_loss(self, state: TrainingState) -> Scalar | None: + self._ensure_step(state) + if self.val_data is None: + return None + + if "val_loss" not in self._cache: + self._cache["val_loss"] = self._loss_impl( + state.model, self.val_data + ) + return self._cache["val_loss"] + + +@dataclass +class TrainingContext: + state: TrainingState + evaluator: EvaluationContext + + @property + def model(self) -> PyTree: + return self.state.model + + @property + def optimizer_state(self) -> PyTree: + return self.state.opt_state + + @property + def step(self) -> int: + return self.state.step + + @property + def loss(self) -> Scalar: + return self.evaluator.loss(self.state) + + @property + def val_loss(self) -> Scalar | None: + return self.evaluator.val_loss(self.state) + + +class Callback(ABC): + """An abstract callback. + + Inherit from this class to create a custom callback. + """ + + def __call__(self, ctx: TrainingContext) -> bool | None: + """Call after each step during training.""" + pass + + def on_training_end(self, ctx: TrainingContext) -> None: + """Call when training ends.""" + pass + + def on_training_start(self, ctx: TrainingContext) -> None: + """Call when training starts.""" + pass + + def fit_core[T: eqx.Module]( updater: Updater, batcher: Generator[PyTree[Any], TrainingState, None], - state: TrainingState, + ctx: TrainingContext, steps: int, + callbacks: Iterable[Callback] | None = None, ): @eqx.filter_jit def make_step(batch, flat_model, flat_opt_state): # Use the unflatten trick to speed up training, # see https://docs.kidger.site/equinox/tricks/ - model = jax.tree_util.tree_unflatten(state.treedef_model, flat_model) + model = jax.tree_util.tree_unflatten( + ctx.state.treedef_model, flat_model + ) opt_state = jax.tree_util.tree_unflatten( - state.treedef_opt_state, flat_opt_state + ctx.state.treedef_opt_state, flat_opt_state ) # Compute and apply the parameter updates @@ -280,14 +394,30 @@ def make_step(batch, flat_model, flat_opt_state): return flat_model, flat_opt_state - for state.step in range(state.step, state.step + steps + 1): + # Make callbacks iterable + callbacks = [] if callbacks is None else list(callbacks) + + for callback in callbacks: + callback.on_training_start(ctx) + + for _ in range(steps): next(batcher) - batch = batcher.send(state) # Send the state back to the batcher - state.flat_model, state.flat_opt_state = make_step( - batch, state.flat_model, state.flat_opt_state + batch = batcher.send(ctx.state) # Send the state back to the batcher + flat_model, flat_opt_state = make_step( + batch, ctx.state.flat_model, ctx.state.flat_opt_state ) - return state + ctx.state.update(flat_model, flat_opt_state) + + # Run all callbacks and break if any of them request termination of + # the training loop. + # Note! The square brackets are important. Otherwise the loop is + # terminated with the first callback that returns true. But we want + # to run all callbacks first and then decide, whether to terminate. + if any([callback(ctx) for callback in callbacks]): + break + + return ctx def fit( @@ -296,6 +426,7 @@ def fit( *, batch_size: int = 32, batch_axis: PyTree[int | None] = 0, + validation_data: PyTree[Any] = None, steps: int = 1_000, loss: LossFactory = mse, optimizer: optax.GradientTransformation, @@ -304,6 +435,8 @@ def fit( updater: UpdaterFactory = optax_transform_update_fn_updater, key: PRNGKeyArray, ): + value_fn, value_and_grad_fn = loss(batch_axis) + evaluator = EvaluationContext(value_fn, data, val_data=validation_data) state = TrainingState( model=model, opt_state=optimizer.init(eqx.filter(model, eqx.is_inexact_array)) @@ -312,16 +445,20 @@ def fit( if init_opt_state is None else init_opt_state, ) + ctx = TrainingContext( + state=state, + evaluator=evaluator, + ) - state = fit_core( - updater(optimizer.update, *loss(batch_axis)), + ctx = fit_core( + updater(optimizer.update, value_fn, value_and_grad_fn), batcher( data=data, batch_axis=batch_axis, batch_size=batch_size, key=key ), - state, + ctx, steps, ) - return state.model + return ctx.state.model if __name__ == "__main__": @@ -338,22 +475,30 @@ def fit( # Test fit_core x = jnp.linspace(0.0, 1.0, 2).reshape(-1, 1) y = 2.0 * x + 1.0 + data = (x, y) model = eqx.nn.Linear(1, 1, key=eqx.internal.GetKey()()) batch_axis = 0 optimizer = optax.adam(1.0) + value_fn, value_and_grad_fn = mse(batch_axis) + evaluator = EvaluationContext(value_fn, (x, y)) state = TrainingState( model=model, opt_state=optimizer.init(eqx.filter(model, eqx.is_inexact_array)), ) + ctx = TrainingContext( + state=state, + evaluator=evaluator, + ) batcher = stateful_batch_data( (x, y), batch_size=32, batch_axis=batch_axis, key=eqx.internal.GetKey()(), # Unused ) - loss = mse(batch_axis) - updater = optax_transform_update_fn_updater(optimizer.update, *loss) - state = fit_core(updater, batcher, state, steps=1000) + updater = optax_transform_update_fn_updater( + optimizer.update, value_fn, value_and_grad_fn + ) + state = fit_core(updater, batcher, ctx, steps=1000) y_pred = jax.vmap(state.model)(x) assert jnp.allclose(y_pred, y) From 5294cc7c4b80ac2f3886442970f317fc276adb8f Mon Sep 17 00:00:00 2001 From: jaosch Date: Sat, 13 Sep 2025 13:38:40 +0200 Subject: [PATCH 6/9] Added call of on_training_end --- klax/_new_training.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/klax/_new_training.py b/klax/_new_training.py index e701676..7c69596 100644 --- a/klax/_new_training.py +++ b/klax/_new_training.py @@ -417,6 +417,10 @@ def make_step(batch, flat_model, flat_opt_state): if any([callback(ctx) for callback in callbacks]): break + # Call callbacks after training + for callback in callbacks: + callback.on_training_end(ctx) + return ctx From 66dcc6c948b6631e497f04831f856a91dab705a3 Mon Sep 17 00:00:00 2001 From: jaosch Date: Sat, 13 Sep 2025 13:50:36 +0200 Subject: [PATCH 7/9] Pass full `TrainingContext` into the batcher --- klax/_new_training.py | 124 +++++++++++++++++++++--------------------- 1 file changed, 62 insertions(+), 62 deletions(-) diff --git a/klax/_new_training.py b/klax/_new_training.py index 7c69596..bf85d4f 100644 --- a/klax/_new_training.py +++ b/klax/_new_training.py @@ -222,66 +222,6 @@ def wrapper(model, batch, opt_state): return wrapper -@typing.runtime_checkable -class Batcher(Protocol): - @abstractmethod - def __call__( - self, - data: PyTree[Any], - batch_size: int, - batch_axis: int, - *, - key: PRNGKeyArray, - ) -> Generator[PyTree[Any], TrainingState, None]: - pass - - -def stateful_batch_data( - data: PyTree[Any], - batch_size: int, - batch_axis: int, - convert_to_numpy: bool = True, - *, - key: PRNGKeyArray, # Only cor compliance with the `Batcher` protocol -) -> Generator[PyTree[Any], TrainingState, None]: - """Create a stateful batch generator that uses the step as seed.""" - batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis) - - # Convert to Numpy arrays. Numpy's slicing is much faster than JAX's, so - # for fast model training steps this actually makes a huge difference! - # However, be aware that this is likely only true if JAX runs on CPU. - if convert_to_numpy: - data = jax.tree.map( - lambda x, a: x if a is None else np.array(x), - data, - batch_axis, - is_leaf=lambda x: x is None, - ) - - # Reduce batch size if the dataset has less examples than batch size - batch_size = min(batch_size, dataset_size) - - indices = jnp.arange(dataset_size) - while True: - # Store the training state as received by the `.send(state)` within - # the training loop. - state: TrainingState = yield - key = jax.random.PRNGKey(state.step) # Create key from step - perm = jr.permutation(key, indices) - (key,) = jr.split(key, 1) # Update key - start, end = 0, batch_size - while end <= dataset_size: - batch_perm = perm[start:end] - yield jax.tree.map( - lambda a, x: x if a is None else x[batch_perm], - batch_axis, - data, - is_leaf=lambda x: x is None, - ) - start = end - end = start + batch_size - - @dataclass class EvaluationContext: value_fn: ValueFn @@ -362,9 +302,69 @@ def on_training_start(self, ctx: TrainingContext) -> None: pass +@typing.runtime_checkable +class Batcher(Protocol): + @abstractmethod + def __call__( + self, + data: PyTree[Any], + batch_size: int, + batch_axis: int, + *, + key: PRNGKeyArray, + ) -> Generator[PyTree[Any], TrainingState, None]: + pass + + +def stateful_batch_data( + data: PyTree[Any], + batch_size: int, + batch_axis: int, + convert_to_numpy: bool = True, + *, + key: PRNGKeyArray, # Only cor compliance with the `Batcher` protocol +) -> Generator[PyTree[Any], TrainingContext, None]: + """Create a stateful batch generator that uses the step as seed.""" + batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis) + + # Convert to Numpy arrays. Numpy's slicing is much faster than JAX's, so + # for fast model training steps this actually makes a huge difference! + # However, be aware that this is likely only true if JAX runs on CPU. + if convert_to_numpy: + data = jax.tree.map( + lambda x, a: x if a is None else np.array(x), + data, + batch_axis, + is_leaf=lambda x: x is None, + ) + + # Reduce batch size if the dataset has less examples than batch size + batch_size = min(batch_size, dataset_size) + + indices = jnp.arange(dataset_size) + while True: + # Store the training state as received by the `.send(state)` within + # the training loop. + ctx: TrainingContext = yield + key = jax.random.PRNGKey(ctx.state.step) # Create key from step + perm = jr.permutation(key, indices) + (key,) = jr.split(key, 1) # Update key + start, end = 0, batch_size + while end <= dataset_size: + batch_perm = perm[start:end] + yield jax.tree.map( + lambda a, x: x if a is None else x[batch_perm], + batch_axis, + data, + is_leaf=lambda x: x is None, + ) + start = end + end = start + batch_size + + def fit_core[T: eqx.Module]( updater: Updater, - batcher: Generator[PyTree[Any], TrainingState, None], + batcher: Generator[PyTree[Any], TrainingContext, None], ctx: TrainingContext, steps: int, callbacks: Iterable[Callback] | None = None, @@ -402,7 +402,7 @@ def make_step(batch, flat_model, flat_opt_state): for _ in range(steps): next(batcher) - batch = batcher.send(ctx.state) # Send the state back to the batcher + batch = batcher.send(ctx) # Send the context flat_model, flat_opt_state = make_step( batch, ctx.state.flat_model, ctx.state.flat_opt_state ) From e35d89a18bc5303b03e6d87d473cc883b2de9f31 Mon Sep 17 00:00:00 2001 From: jaosch Date: Sat, 13 Sep 2025 16:39:42 +0200 Subject: [PATCH 8/9] New HistoryCallback + distribution over multiple files --- klax/__init__.py | 17 +- klax/_callbacks.py | 317 ++++++++++++++------------ klax/_context.py | 235 +++++++++++++++++++ klax/_datahandler.py | 203 +++++++++++------ klax/_losses.py | 104 +++------ klax/_new_training.py | 511 ------------------------------------------ klax/_training.py | 346 ++++++++++++---------------- klax/_updaters.py | 67 ++++++ 8 files changed, 785 insertions(+), 1015 deletions(-) create mode 100644 klax/_context.py delete mode 100644 klax/_new_training.py create mode 100644 klax/_updaters.py diff --git a/klax/__init__.py b/klax/__init__.py index 385282b..6cd54ae 100644 --- a/klax/__init__.py +++ b/klax/__init__.py @@ -16,12 +16,15 @@ from ._callbacks import ( Callback as Callback, ) -from ._callbacks import ( - CallbackArgs as CallbackArgs, -) from ._callbacks import ( HistoryCallback as HistoryCallback, ) +from ._context import ( + EvaluationContext, + TimingInfo, + TrainingContext, + TrainingState, +) from ._datahandler import ( BatchGenerator as BatchGenerator, ) @@ -32,13 +35,7 @@ split_data as split_data, ) from ._losses import ( - MAE as MAE, -) -from ._losses import ( - MSE as MSE, -) -from ._losses import ( - Loss as Loss, + LossFactory as LossFactory, ) from ._losses import ( mae as mae, diff --git a/klax/_callbacks.py b/klax/_callbacks.py index c4df84a..d1acd5a 100644 --- a/klax/_callbacks.py +++ b/klax/_callbacks.py @@ -16,7 +16,7 @@ import importlib import pickle import time -from abc import ABC +from abc import ABC, abstractmethod from collections.abc import Callable from pathlib import Path from typing import Any, Self @@ -24,132 +24,7 @@ import jax from jaxtyping import PyTree, PyTreeDef, Scalar - -class CallbackArgs: - """A callback argument designed to work in conjunction with [`klax.fit`][]. - - This class should not be instantiated directly. An instance of this class - is passed to every callback object in the fit function. When writing a - custom callback, use the properties of this class to access the current - model, optimizer state, training data, and validation data during training. - - This class implements cached and lazy-evaluated values via property - methods. This means that properties like ``loss`` are only calculated if - they are used and are stored such that they are not calculated multiple - times. - """ - - step: int #: Current step-count of the training. - time_on_last_update: float #: Global time of the last :meth:`update` call. - data: PyTree #: PyTree of the training data. - val_data: PyTree | None #: PyTree of the validation data. - _treedef_model: PyTreeDef - _flat_model: list - _treedef_opt_state: PyTreeDef - _flat_opt_state: list - _cache: dict = {} - _get_loss: Callable[[PyTree, PyTree], Scalar] - _start_time: float - - def __init__( - self, - get_loss: Callable[[PyTree, PyTree], Scalar], - treedef_model: PyTreeDef, - treedef_opt_state: PyTreeDef, - data: PyTree, - val_data: PyTree | None = None, - ): - """Initialize the callback arguments object. - - Args: - get_loss: Function that takes a model and a batch of data and - returns the loss. - treedef_model: Tree structure of the model. - treedef_opt_state: Tree structure of the :py:mod:`optax` optimizer. - data: PyTree of the training data. - val_data: PyTree of the validation data. If None, no validation - loss is calculated and the property :py:attr:`val_loss` will - return None. - - """ - self.data = data - self.val_data = val_data - self._get_loss = get_loss - self._treedef_model = treedef_model - self._treedef_opt_state = treedef_opt_state - - def update(self, flat_model: PyTree, flat_opt_state: PyTree, step: int): - """Update the object with the current model and optimizer state. - - This method is called repeatedly in [`klax.fit`][]. - - Args: - flat_model: Flattened PyTree of the model. - flat_opt_state: Flattened PyTree of the `optax` - optimizer. - step: Current step-count of the training. - - """ - self._flat_model = flat_model - self._flat_opt_state = flat_opt_state - self.step = step - self.time_on_last_update = time.time() - - # Clear cache - self._cache = {} - - @staticmethod - def _lazy_evaluated_and_cached(fun: Callable[[Any], Any]) -> property: - """Turn a public method into a property. - - The return value of ``fun`` is stored in the ``_cache`` dictionary of - the current object using the function name as key. If the name is - already in ``_cache`` then the cached value is simply returned, - without evaluating ``fun``. - - Args: - fun: Method to wrap. - - Returns: - Wrapped method as a property. - - """ - attr_name = fun.__name__ - - def wrapper(self: Self): - if attr_name not in self._cache: - self._cache.setdefault(attr_name, fun(self)) - return self._cache.get(attr_name) - - wrapper.__doc__ = fun.__doc__ - - return property(wrapper) - - @_lazy_evaluated_and_cached - def model(self): - """Lazy-evaluated and cached model.""" - return jax.tree_util.tree_unflatten( - self._treedef_model, self._flat_model - ) - - @_lazy_evaluated_and_cached - def opt_state(self): - """Lazy-evaluated and cached optimizer state.""" - return jax.tree_util.tree_unflatten( - self._treedef_opt_state, self._flat_opt_state - ) - - @_lazy_evaluated_and_cached - def loss(self): - """Lazy-evaluated and cached training loss.""" - return self._get_loss(self.model, self.data) - - @_lazy_evaluated_and_cached - def val_loss(self) -> Scalar | None: - """Lazy-evaluated and cached validation loss.""" - if self.val_data is None: - return None - return self._get_loss(self.model, self.val_data) +from ._context import TrainingContext class Callback(ABC): @@ -158,19 +33,168 @@ class Callback(ABC): Inherit from this class to create a custom callback. """ - def __call__(self, cbargs: CallbackArgs) -> bool | None: + @abstractmethod + def __call__(self, ctx: TrainingContext) -> bool | None: """Call after each step during training.""" pass - def on_training_end(self, cbargs: CallbackArgs) -> None: + @abstractmethod + def on_training_end(self, ctx: TrainingContext) -> None: """Call when training ends.""" pass - def on_training_start(self, cbargs: CallbackArgs) -> None: + @abstractmethod + def on_training_start(self, ctx: TrainingContext) -> None: """Call when training starts.""" pass +# class CallbackArgs: +# """A callback argument designed to work in conjunction with [`klax.fit`][]. + +# This class should not be instantiated directly. An instance of this class +# is passed to every callback object in the fit function. When writing a +# custom callback, use the properties of this class to access the current +# model, optimizer state, training data, and validation data during training. + +# This class implements cached and lazy-evaluated values via property +# methods. This means that properties like ``loss`` are only calculated if +# they are used and are stored such that they are not calculated multiple +# times. +# """ + +# step: int #: Current step-count of the training. +# time_on_last_update: float #: Global time of the last :meth:`update` call. +# data: PyTree #: PyTree of the training data. +# val_data: PyTree | None #: PyTree of the validation data. +# _treedef_model: PyTreeDef +# _flat_model: list +# _treedef_opt_state: PyTreeDef +# _flat_opt_state: list +# _cache: dict = {} +# _get_loss: Callable[[PyTree, PyTree], Scalar] +# _start_time: float + +# def __init__( +# self, +# get_loss: Callable[[PyTree, PyTree], Scalar], +# treedef_model: PyTreeDef, +# treedef_opt_state: PyTreeDef, +# data: PyTree, +# val_data: PyTree | None = None, +# ): +# """Initialize the callback arguments object. + +# Args: +# get_loss: Function that takes a model and a batch of data and +# returns the loss. +# treedef_model: Tree structure of the model. +# treedef_opt_state: Tree structure of the :py:mod:`optax` optimizer. +# data: PyTree of the training data. +# val_data: PyTree of the validation data. If None, no validation +# loss is calculated and the property :py:attr:`val_loss` will +# return None. + +# """ +# self.data = data +# self.val_data = val_data +# self._get_loss = get_loss +# self._treedef_model = treedef_model +# self._treedef_opt_state = treedef_opt_state + +# def update(self, flat_model: PyTree, flat_opt_state: PyTree, step: int): +# """Update the object with the current model and optimizer state. + +# This method is called repeatedly in [`klax.fit`][]. + +# Args: +# flat_model: Flattened PyTree of the model. +# flat_opt_state: Flattened PyTree of the `optax` +# optimizer. +# step: Current step-count of the training. + +# """ +# self._flat_model = flat_model +# self._flat_opt_state = flat_opt_state +# self.step = step +# self.time_on_last_update = time.time() + +# # Clear cache +# self._cache = {} + +# @staticmethod +# def _lazy_evaluated_and_cached(fun: Callable[[Any], Any]) -> property: +# """Turn a public method into a property. + +# The return value of ``fun`` is stored in the ``_cache`` dictionary of +# the current object using the function name as key. If the name is +# already in ``_cache`` then the cached value is simply returned, +# without evaluating ``fun``. + +# Args: +# fun: Method to wrap. + +# Returns: +# Wrapped method as a property. + +# """ +# attr_name = fun.__name__ + +# def wrapper(self: Self): +# if attr_name not in self._cache: +# self._cache.setdefault(attr_name, fun(self)) +# return self._cache.get(attr_name) + +# wrapper.__doc__ = fun.__doc__ + +# return property(wrapper) + +# @_lazy_evaluated_and_cached +# def model(self): +# """Lazy-evaluated and cached model.""" +# return jax.tree_util.tree_unflatten( +# self._treedef_model, self._flat_model +# ) + +# @_lazy_evaluated_and_cached +# def opt_state(self): +# """Lazy-evaluated and cached optimizer state.""" +# return jax.tree_util.tree_unflatten( +# self._treedef_opt_state, self._flat_opt_state +# ) + +# @_lazy_evaluated_and_cached +# def loss(self): +# """Lazy-evaluated and cached training loss.""" +# return self._get_loss(self.model, self.data) + +# @_lazy_evaluated_and_cached +# def val_loss(self) -> Scalar | None: +# """Lazy-evaluated and cached validation loss.""" +# if self.val_data is None: +# return None +# return self._get_loss(self.model, self.val_data) + + +# class Callback(ABC): +# """An abstract callback. + +# Inherit from this class to create a custom callback. +# """ + +# def __call__(self, cbargs: CallbackArgs) -> bool | None: +# """Call after each step during training.""" +# pass + +# def on_training_end(self, cbargs: CallbackArgs) -> None: +# """Call when training ends.""" +# pass + +# def on_training_start(self, cbargs: CallbackArgs) -> None: +# """Call when training starts.""" +# pass + + class HistoryCallback(Callback): """Default callback for logging a training process. @@ -181,8 +205,6 @@ class HistoryCallback(Callback): steps: list #: List of steps at which the losses were recorded. loss: list val_loss: list - last_start_time: float # start time of the last training - last_end_time: float # End time of the last training training_time: float = 0 # Total training time of all trainings verbose: bool step_offset: int = 0 # Potential offset due to previous trainings @@ -211,44 +233,45 @@ def __repr__(self): f"verbose={self.verbose})" ) - def __call__(self, cbargs: CallbackArgs): + def __call__(self, ctx: TrainingContext): """Record the losses and step count. Called at each step during training. """ - if cbargs.step % self.log_every == 0: - self.steps.append(self.step_offset + cbargs.step) - self.loss.append(cbargs.loss) - self.val_loss.append(cbargs.val_loss) + if ctx.step % self.log_every == 0: + self.steps.append(self.step_offset + ctx.step) + self.loss.append(ctx.loss) + self.val_loss.append(ctx.val_loss) # Print message if self.verbose: - message = f"Step: {cbargs.step}, Loss: {cbargs.loss:.3e}" - if cbargs.val_data is not None: - message += f", Validation loss: {cbargs.val_loss:.3e}" + message = f"Step: {ctx.step}, Loss: {ctx.loss:.3e}" + if ctx.val_data is not None: + message += f", Validation loss: {ctx.val_loss:.3e}" print(message) - def on_training_start(self, cbargs: CallbackArgs): + def on_training_start(self, ctx: TrainingContext): """Initialize the training start time. Called at beginning of training. """ - self.last_start_time = cbargs.time_on_last_update + assert ctx.start_time is not None + self.last_start_time = ctx.start_time if self.steps: # If there are already steps, we assume that this is a continuation # of a training. self.step_offset = self.steps[-1] else: - self(cbargs) + self(ctx) - def on_training_end(self, cbargs: CallbackArgs): + def on_training_end(self, ctx: TrainingContext): """Record the training end time and the last optimizer state. Called at end of training. """ - self.last_end_time = cbargs.time_on_last_update - self.training_time += self.last_end_time - self.last_start_time - self.last_opt_state = cbargs.opt_state + assert ctx.total_time is not None + self.training_time += ctx.total_time + self.last_opt_state = ctx.opt_state if self.verbose: print( f"Training took: { diff --git a/klax/_context.py b/klax/_context.py new file mode 100644 index 0000000..b86f722 --- /dev/null +++ b/klax/_context.py @@ -0,0 +1,235 @@ +import time +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Self + +import equinox as eqx +import jax +from jaxtyping import PyTree, Scalar + +from ._losses import ValueFn +from ._wrappers import apply + + +class TrainingState: + _flat_model: PyTree + _flat_opt_state: PyTree + _treedef_model: PyTree + _treedef_opt_state: PyTree + _step: int + _cache: dict[str, Any] = {} + + def __init__( + self, model: PyTree, opt_state: PyTree = None, initial_step: int = 0 + ): + # Apply the Constraint in the model to ensure apply-constrains are met + # initially + model = apply(model) + + # Use the unflatten trick to speed up training, + # see https://docs.kidger.site/equinox/tricks/ + flat_model, treedef_model = jax.tree.flatten(model) + flat_opt_state, treedef_opt_state = jax.tree.flatten(opt_state) + + self._flat_model = flat_model + self._flat_opt_state = flat_opt_state + self._treedef_model = treedef_model + self._treedef_opt_state = treedef_opt_state + self._step = initial_step + + @staticmethod + def _lazy_chached_property(fun: Callable) -> property: + """Turn a public method into a lazily evaluated property. + + The return value of ``fun`` is stored in the ``_cache`` dictionary of + the current object using the function name as key. If the name is + already in ``_cache`` then the cached value is simply returned, + without evaluating ``fun``. + + Args: + fun: Method to wrap. + + Returns: + Wrapped method as a property. + + """ + attr_name = fun.__name__ + + def wrapper(self: Self): + if attr_name not in self._cache: + self._cache.setdefault(attr_name, fun(self)) + return self._cache.get(attr_name) + + wrapper.__doc__ = fun.__doc__ + + return property(wrapper) + + @property + def flat_model(self) -> PyTree: + return self._flat_model + + @property + def flat_opt_state(self) -> PyTree: + return self._flat_opt_state + + @property + def step(self) -> int: + return self._step + + @_lazy_chached_property + def model(self) -> PyTree: + return jax.tree_util.tree_unflatten( + self._treedef_model, self._flat_model + ) + + @_lazy_chached_property + def opt_state(self) -> PyTree: + return jax.tree_util.tree_unflatten( + self._treedef_opt_state, self._flat_opt_state + ) + + @property + def treedef_model(self) -> PyTree: + return self._treedef_model + + @property + def treedef_opt_state(self) -> PyTree: + return self._treedef_opt_state + + def update(self, flat_model: PyTree, flat_opt_state: PyTree): + self._flat_model = flat_model + self._flat_opt_state = flat_opt_state + self._step += 1 + + # Clear cache + self._cache.clear() + + +@dataclass +class EvaluationContext: + value_fn: ValueFn + data: PyTree[Any] + val_data: PyTree[Any] | None = None + _cached_step: int | None = None + _cache: dict[str, Any] = field(default_factory=dict) + + def _ensure_step(self, state: TrainingState): + if self._cached_step != state.step: + self._cache.clear() + self._cached_step = state.step + + @staticmethod + @eqx.filter_jit + def _loss_impl(value_fn: ValueFn, model: PyTree, batch: PyTree[Any]): + return value_fn(model, batch) + + def loss(self, state: TrainingState) -> Scalar: + self._ensure_step(state) + if "loss" not in self._cache: + self._cache["loss"] = self._loss_impl( + self.value_fn, state.model, self.data + ) + return self._cache["loss"] + + def val_loss(self, state: TrainingState) -> Scalar | None: + self._ensure_step(state) + if self.val_data is None: + return None + + if "val_loss" not in self._cache: + self._cache["val_loss"] = self._loss_impl( + self.value_fn, state.model, self.val_data + ) + return self._cache["val_loss"] + + +@dataclass +class TimingInfo: + start_time: float | None = None + total_time: float = 0.0 + time_of_last_update: float | None = None + + def update(self): + time_of_last_update = time.time() + if self.start_time is None: + self.start_time = time_of_last_update + else: + self.total_time = time_of_last_update - self.start_time + self.time_of_last_update = time_of_last_update + + +@dataclass +class TrainingContext: + _state: TrainingState + _evaluator: EvaluationContext + _timer: TimingInfo + + def __init__( + self, + state: TrainingState, + evaluator: EvaluationContext, + timing: TimingInfo, + ): + self._state = state + self._evaluator = evaluator + self._timer = timing + + def update(self, flat_model: PyTree, flat_opt_state: PyTree) -> None: + self._state.update(flat_model, flat_opt_state) + self._timer.update() + + @property + def flat_opt_state(self) -> PyTree: + return self._state.flat_opt_state + + @property + def flat_model(self) -> PyTree: + return self._state.flat_model + + @property + def model(self) -> PyTree: + return self._state.model + + @property + def treedef_model(self) -> PyTree: + return self._state.treedef_model + + @property + def opt_state(self) -> PyTree: + return self._state.opt_state + + @property + def treedef_opt_state(self) -> PyTree: + return self._state.treedef_opt_state + + @property + def step(self) -> int: + return self._state.step + + @property + def loss(self) -> Scalar: + return self._evaluator.loss(self._state) + + @property + def val_loss(self) -> Scalar | None: + return self._evaluator.val_loss(self._state) + + @property + def data(self) -> PyTree: + return self._evaluator.data + + @property + def val_data(self) -> PyTree: + return self._evaluator.val_data + + @property + def time_of_last_update(self) -> float | None: + return self._timer.time_of_last_update + + @property + def start_time(self) -> float | None: + return self._timer.start_time + + @property + def total_time(self) -> float | None: + return self._timer.total_time diff --git a/klax/_datahandler.py b/klax/_datahandler.py index a9aa536..05b0514 100644 --- a/klax/_datahandler.py +++ b/klax/_datahandler.py @@ -16,6 +16,7 @@ import typing import warnings +from abc import abstractmethod from collections.abc import Generator, Sequence from typing import Any, Protocol @@ -26,6 +27,10 @@ import numpy as np from jaxtyping import PRNGKeyArray, PyTree +from ._context import TrainingContext + +BatchGenerator = Generator[PyTree[Any], TrainingContext, None] + def broadcast_and_get_size( data: PyTree[Any], batch_axis: PyTree[int | None] @@ -78,87 +83,28 @@ def broadcast_and_get_size( @typing.runtime_checkable -class BatchGenerator(Protocol): +class Batcher(Protocol): + @abstractmethod def __call__( self, data: PyTree[Any], batch_size: int, - batch_axis: PyTree[int | None], + batch_axis: int, *, key: PRNGKeyArray, - ) -> Generator[PyTree[Any], None, None]: - raise NotImplementedError + ) -> BatchGenerator: + pass def batch_data( data: PyTree[Any], - batch_size: int = 32, - batch_axis: PyTree[int | None] = 0, + batch_size: int, + batch_axis: int, convert_to_numpy: bool = True, *, - key: PRNGKeyArray, -) -> Generator[PyTree[Any], None, None]: - """Create a `Generator` that draws subsets of data without replacement. - - The data can be any `PyTree` with `ArrayLike` leaves. If `batch_axis` is - passed, batch axes (including `None` for no batching) can be specified for - every leaf individualy. - A generator is returned that indefinetly yields batches of data with size - `batch_size`. Examples are drawn without replacement until the remaining - dataset is smaller than `batch_size`, at which point the dataset will be - reshuffeld and the process starts over. - - Example: - This is an example for a nested `PyTree`, where the elements x and y - have batch dimension along the first axis. - - ```python - >>> import klax - >>> import jax - >>> import jax.numpy as jnp - >>> - >>> x = jnp.array([1., 2.]) - >>> y = jnp.array([[1.], [2.]]) - >>> data = (x, {"a": 1.0, "b": y}) - >>> batch_axis = (0, {"a": None, "b": 0}) - >>> iter_data = klax.batch_data( - ... data, - ... 32, - ... batch_axis, - ... key=jax.random.key(0) - ... ) - >>> - ``` - - Args: - data: The data that shall be batched. It can be any `PyTree` with - `ArrayLike` leaves. - batch_size: The number of examples in a batch. - batch_axis: PyTree of the batch axis indices. `None` is used to - indicate that the corresponding leaf or subtree in data does not - have a batch axis. `batch_axis` must have the same structure as - `data` or have `data` as a prefix. - (Defaults to 0, meaning all leaves in `data` are - batched along their first dimension.) - convert_to_numpy: If `True`, batched data leafs will be converted to - Numpy arrays before batching. This is useful for performance - reasons, as Numpy's slicing is much faster than JAX's. - key: A `jax.random.PRNGKey` used to provide randomness for batch - generation. (Keyword only argument.) - - Returns: - A `Generator` that yields a random batch of data every time is is - called. - - Yields: - A `PyTree[ArrayLike]` with the same structure as `data`. Where all - batched leaves have `batch_size`. - - Note: - Note that if the size of the dataset is smaller than `batch_size`, the - obtained batches will have dataset size. - - """ + key: PRNGKeyArray, # Only cor compliance with the `Batcher` protocol +) -> BatchGenerator: + """Create a stateful batch generator that uses the step as seed.""" batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis) # Convert to Numpy arrays. Numpy's slicing is much faster than JAX's, so @@ -177,6 +123,10 @@ def batch_data( indices = jnp.arange(dataset_size) while True: + # Store the training state as received by the `.send(state)` within + # the training loop. + ctx: TrainingContext = yield + key = jax.random.PRNGKey(ctx.step) # Create key from step perm = jr.permutation(key, indices) (key,) = jr.split(key, 1) # Update key start, end = 0, batch_size @@ -192,6 +142,121 @@ def batch_data( end = start + batch_size +# @typing.runtime_checkable +# class BatchGenerator(Protocol): +# def __call__( +# self, +# data: PyTree[Any], +# batch_size: int, +# batch_axis: PyTree[int | None], +# *, +# key: PRNGKeyArray, +# ) -> Generator[PyTree[Any], None, None]: +# raise NotImplementedError + + +# def batch_data( +# data: PyTree[Any], +# batch_size: int = 32, +# batch_axis: PyTree[int | None] = 0, +# convert_to_numpy: bool = True, +# *, +# key: PRNGKeyArray, +# ) -> Generator[PyTree[Any], None, None]: +# """Create a `Generator` that draws subsets of data without replacement. + +# The data can be any `PyTree` with `ArrayLike` leaves. If `batch_axis` is +# passed, batch axes (including `None` for no batching) can be specified for +# every leaf individualy. +# A generator is returned that indefinetly yields batches of data with size +# `batch_size`. Examples are drawn without replacement until the remaining +# dataset is smaller than `batch_size`, at which point the dataset will be +# reshuffeld and the process starts over. + +# Example: +# This is an example for a nested `PyTree`, where the elements x and y +# have batch dimension along the first axis. + +# ```python +# >>> import klax +# >>> import jax +# >>> import jax.numpy as jnp +# >>> +# >>> x = jnp.array([1., 2.]) +# >>> y = jnp.array([[1.], [2.]]) +# >>> data = (x, {"a": 1.0, "b": y}) +# >>> batch_axis = (0, {"a": None, "b": 0}) +# >>> iter_data = klax.batch_data( +# ... data, +# ... 32, +# ... batch_axis, +# ... key=jax.random.key(0) +# ... ) +# >>> +# ``` + +# Args: +# data: The data that shall be batched. It can be any `PyTree` with +# `ArrayLike` leaves. +# batch_size: The number of examples in a batch. +# batch_axis: PyTree of the batch axis indices. `None` is used to +# indicate that the corresponding leaf or subtree in data does not +# have a batch axis. `batch_axis` must have the same structure as +# `data` or have `data` as a prefix. +# (Defaults to 0, meaning all leaves in `data` are +# batched along their first dimension.) +# convert_to_numpy: If `True`, batched data leafs will be converted to +# Numpy arrays before batching. This is useful for performance +# reasons, as Numpy's slicing is much faster than JAX's. +# key: A `jax.random.PRNGKey` used to provide randomness for batch +# generation. (Keyword only argument.) + +# Returns: +# A `Generator` that yields a random batch of data every time is is +# called. + +# Yields: +# A `PyTree[ArrayLike]` with the same structure as `data`. Where all +# batched leaves have `batch_size`. + +# Note: +# Note that if the size of the dataset is smaller than `batch_size`, the +# obtained batches will have dataset size. + +# """ +# batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis) + +# # Convert to Numpy arrays. Numpy's slicing is much faster than JAX's, so +# # for fast model training steps this actually makes a huge difference! +# # However, be aware that this is likely only true if JAX runs on CPU. +# if convert_to_numpy: +# data = jax.tree.map( +# lambda x, a: x if a is None else np.array(x), +# data, +# batch_axis, +# is_leaf=lambda x: x is None, +# ) + +# # Reduce batch size if the dataset has less examples than batch size +# batch_size = min(batch_size, dataset_size) + +# indices = jnp.arange(dataset_size) +# while True: +# perm = jr.permutation(key, indices) +# (key,) = jr.split(key, 1) # Update key +# start, end = 0, batch_size +# while end <= dataset_size: +# batch_perm = perm[start:end] +# yield jax.tree.map( +# lambda a, x: x if a is None else x[batch_perm], +# batch_axis, +# data, +# is_leaf=lambda x: x is None, +# ) +# start = end +# end = start + batch_size + + def split_data( data: PyTree[Any], proportions: Sequence[int | float], diff --git a/klax/_losses.py b/klax/_losses.py index 5e5b7b1..184f9de 100644 --- a/klax/_losses.py +++ b/klax/_losses.py @@ -14,109 +14,65 @@ import typing from abc import abstractmethod -from collections.abc import Sequence from typing import Any, Protocol +import equinox as eqx import jax import jax.numpy as jnp from jaxtyping import PyTree, Scalar @typing.runtime_checkable -class Loss(Protocol): - """An abstract callable loss object. +class ValueFn(Protocol): + @abstractmethod + def __call__(self, model: PyTree, data: PyTree) -> Scalar: + pass - It can be used to build custom losses that can be passed to [`klax.fit`][]. - Example: - A simple custom loss that computes the mean squared error between - the predicted values `y_pred` and true values `y` for in inputs `x` may - be implemented as follows: +@typing.runtime_checkable +class ValueAndGradFn(Protocol): + @abstractmethod + def __call__(self, model: PyTree, data: PyTree) -> tuple[Scalar, PyTree]: + pass - ```python - >>> def mse(model, data, batch_axis=0): - ... x, y = data - ... if isinstance(batch_axis, tuple): - ... in_axes = batch_axis[0] - ... else: - ... in_axes = batch_axis - ... y_pred = jax.vmap(model, in_axes=(in_axes,))(x) - ... return jnp.mean(jnp.square(y_pred - y)) - ``` - Note that, since we a aim to provide a maximum of flexibility the users - have to take care of applying `jax.vmap` to the model themselves. +class LossFactory(Protocol): + @abstractmethod + def __call__(self, batch_axis) -> tuple[ValueFn, ValueAndGradFn]: + pass - """ - @abstractmethod - def __call__( - self, - model: PyTree, - data: PyTree, - batch_axis: int | None | Sequence[Any], - ) -> Scalar: - """Abstract method to compute the loss for a given model and data. - - Args: - model: The model parameters or structure to evaluate the loss. - data: The input data or structure used for loss computation. - batch_axis: Specifies the axis or axes corresponding to the batch - dimension in the data. Can be an integer, None, or a sequence - of values. - - Returns: - Scalar: The computed loss value. - - """ - ... - - -class MSE(Loss): - """Mean squared error for a tuple of data `(x, y)`. - - The inputs `x` and the outputs `y` are expected to have the same batch axis - and equal length along that axis. - """ - - def __call__( - self, - model: PyTree, - data: PyTree, - batch_axis: int | None | Sequence[Any] = 0, - ) -> Scalar: +def mse(batch_axis) -> tuple[ValueFn, ValueAndGradFn]: + def value_fn(model: PyTree, data: PyTree) -> Scalar: x, y = data if isinstance(batch_axis, tuple): in_axes = batch_axis[0] else: in_axes = batch_axis - y_pred = jax.vmap(model, in_axes=(in_axes,))(x) + y_pred = eqx.filter_vmap(model, in_axes=(in_axes,))(x) return jnp.mean(jnp.square(y_pred - y)) + def value_and_grad_fn( + model: PyTree, data: PyTree + ) -> tuple[Scalar, PyTree]: + return eqx.filter_value_and_grad(value_fn)(model, data) -mse = MSE() - - -class MAE(Loss): - """Mean absolute error for a tuple of data `(x, y)`. + return value_fn, value_and_grad_fn - The inputs `x` and the outputs `y` are expected to have the same batch axis - and equal length along that axis. - """ - def __call__( - self, - model: PyTree, - data: PyTree, - batch_axis: int | None | Sequence[Any] = 0, - ) -> Scalar: +def mae(batch_axis) -> tuple[ValueFn, ValueAndGradFn]: + def value_fn(model: PyTree, data: PyTree) -> Scalar: x, y = data if isinstance(batch_axis, tuple): in_axes = batch_axis[0] else: in_axes = batch_axis - y_pred = jax.vmap(model, in_axes=(in_axes,))(x) + y_pred = eqx.filter_vmap(model, in_axes=(in_axes,))(x) return jnp.mean(jnp.abs(y_pred - y)) + def value_and_grad_fn( + model: PyTree, data: PyTree + ) -> tuple[Scalar, PyTree]: + return eqx.filter_value_and_grad(value_fn)(model, data) -mae = MAE() + return value_fn, value_and_grad_fn diff --git a/klax/_new_training.py b/klax/_new_training.py deleted file mode 100644 index bf85d4f..0000000 --- a/klax/_new_training.py +++ /dev/null @@ -1,511 +0,0 @@ -# Copyright 2025 The Klax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Implements a training loop.""" - -import typing -from abc import ABC, abstractmethod -from collections.abc import Callable, Generator, Iterable -from dataclasses import dataclass, field -from typing import Any, Protocol, Self - -import equinox as eqx -import jax -import jax.numpy as jnp -import jax.random as jr -import numpy as np -import optax -from jaxtyping import PRNGKeyArray, PyTree, Scalar - -from ._datahandler import broadcast_and_get_size -from ._wrappers import apply - - -class TrainingState: - _flat_model: PyTree - _flat_opt_state: PyTree - _treedef_model: PyTree - _treedef_opt_state: PyTree - _step: int - _cache: dict[str, Any] = {} - - def __init__( - self, model: PyTree, opt_state: PyTree = None, initial_step: int = 0 - ): - # Apply the Constraint in the model to ensure apply-constrains are met - # initially - model = apply(model) - - # Use the unflatten trick to speed up training, - # see https://docs.kidger.site/equinox/tricks/ - flat_model, treedef_model = jax.tree.flatten(model) - flat_opt_state, treedef_opt_state = jax.tree.flatten(opt_state) - - self._flat_model = flat_model - self._flat_opt_state = flat_opt_state - self._treedef_model = treedef_model - self._treedef_opt_state = treedef_opt_state - self._step = initial_step - - @staticmethod - def _lazy_chached_property(fun: Callable) -> property: - """Turn a public method into a lazily evaluated property. - - The return value of ``fun`` is stored in the ``_cache`` dictionary of - the current object using the function name as key. If the name is - already in ``_cache`` then the cached value is simply returned, - without evaluating ``fun``. - - Args: - fun: Method to wrap. - - Returns: - Wrapped method as a property. - - """ - attr_name = fun.__name__ - - def wrapper(self: Self): - if attr_name not in self._cache: - self._cache.setdefault(attr_name, fun(self)) - return self._cache.get(attr_name) - - wrapper.__doc__ = fun.__doc__ - - return property(wrapper) - - @property - def flat_model(self) -> PyTree: - return self._flat_model - - @property - def flat_opt_state(self) -> PyTree: - return self._flat_opt_state - - @property - def step(self) -> int: - return self._step - - @_lazy_chached_property - def model(self) -> PyTree: - return jax.tree_util.tree_unflatten( - self._treedef_model, self._flat_model - ) - - @_lazy_chached_property - def opt_state(self) -> PyTree: - return jax.tree_util.tree_unflatten( - self._treedef_opt_state, self._flat_opt_state - ) - - @property - def treedef_model(self) -> PyTree: - return self._treedef_model - - @property - def treedef_opt_state(self) -> PyTree: - return self._treedef_opt_state - - def update(self, flat_model: PyTree, flat_opt_state: PyTree): - self._flat_model = flat_model - self._flat_opt_state = flat_opt_state - self._step += self._step - - # Clear cache - self._cache.clear() - - -@typing.runtime_checkable -class ValueFn(Protocol): - @abstractmethod - def __call__(self, model: PyTree, data: PyTree) -> Scalar: - pass - - -@typing.runtime_checkable -class ValueAndGradFn(Protocol): - @abstractmethod - def __call__(self, model: PyTree, data: PyTree) -> tuple[Scalar, PyTree]: - pass - - -class LossFactory(Protocol): - @abstractmethod - def __call__(self, batch_axis) -> tuple[ValueFn, ValueAndGradFn]: - pass - - -def mse(batch_axis) -> tuple[ValueFn, ValueAndGradFn]: - def value_fn(model: PyTree, data: PyTree) -> Scalar: - x, y = data - if isinstance(batch_axis, tuple): - in_axes = batch_axis[0] - else: - in_axes = batch_axis - y_pred = eqx.filter_vmap(model, in_axes=(in_axes,))(x) - return jnp.mean(jnp.square(y_pred - y)) - - def value_and_grad_fn( - model: PyTree, data: PyTree - ) -> tuple[Scalar, PyTree]: - return eqx.filter_value_and_grad(value_fn)(model, data) - - return value_fn, value_and_grad_fn - - -class Updater(Protocol): - @abstractmethod - def __call__( - self, model: PyTree, batch: PyTree, opt_state: PyTree - ) -> tuple[PyTree, PyTree]: - pass - - -class UpdaterFactory(Protocol): - @abstractmethod - def __call__( - self, - opt_update: optax.TransformUpdateFn | optax.TransformUpdateExtraArgsFn, - value_fn: ValueFn, - value_and_grad_fn: ValueAndGradFn, - ) -> Updater: - pass - - -def optax_transform_update_fn_updater( - opt_update: optax.TransformUpdateFn, - value_fn: ValueFn, - value_and_grad_fn: ValueAndGradFn, -) -> Updater: - def wrapper(model, batch, opt_state): - _, grad = value_and_grad_fn(model, batch) - updates, opt_state = opt_update( - grad, - opt_state, - eqx.filter(model, eqx.is_inexact_array), - ) - model = eqx.apply_updates(model, updates) - return model, opt_state - - return wrapper - - -def optax_transform_update_fn_extra_args_updater( - opt_update: optax.TransformUpdateExtraArgsFn, - value_fn: ValueFn, - value_and_grad_fn: ValueAndGradFn, -) -> Updater: - def wrapper(model, batch, opt_state): - value, grad = value_and_grad_fn(model, batch) - updates, opt_state = opt_update( - grad, - opt_state, - eqx.filter(model, eqx.is_inexact_array), - value=value, - grad=grad, - value_fn=jax.tree_util.Partial(value_fn, model=model, batch=batch), - ) - model = eqx.apply_updates(model, updates) - return model, opt_state - - return wrapper - - -@dataclass -class EvaluationContext: - value_fn: ValueFn - data: PyTree[Any] - val_data: PyTree[Any] | None = None - _cached_step: int | None = None - _cache: dict[str, Any] = field(default_factory=dict) - - def _ensure_step(self, state: TrainingState): - if self._cached_step != state.step: - self._cache.clear() - self._cached_step = state.step - - @eqx.filter_jit - def _loss_impl(self, model: PyTree, batch: PyTree[Any]): - return self.value_fn(model, batch) - - def loss(self, state: TrainingState) -> Scalar: - self._ensure_step(state) - if "loss" not in self._cache: - self._cache["loss"] = self._loss_impl(state.model, self.data) - return self._cache["loss"] - - def val_loss(self, state: TrainingState) -> Scalar | None: - self._ensure_step(state) - if self.val_data is None: - return None - - if "val_loss" not in self._cache: - self._cache["val_loss"] = self._loss_impl( - state.model, self.val_data - ) - return self._cache["val_loss"] - - -@dataclass -class TrainingContext: - state: TrainingState - evaluator: EvaluationContext - - @property - def model(self) -> PyTree: - return self.state.model - - @property - def optimizer_state(self) -> PyTree: - return self.state.opt_state - - @property - def step(self) -> int: - return self.state.step - - @property - def loss(self) -> Scalar: - return self.evaluator.loss(self.state) - - @property - def val_loss(self) -> Scalar | None: - return self.evaluator.val_loss(self.state) - - -class Callback(ABC): - """An abstract callback. - - Inherit from this class to create a custom callback. - """ - - def __call__(self, ctx: TrainingContext) -> bool | None: - """Call after each step during training.""" - pass - - def on_training_end(self, ctx: TrainingContext) -> None: - """Call when training ends.""" - pass - - def on_training_start(self, ctx: TrainingContext) -> None: - """Call when training starts.""" - pass - - -@typing.runtime_checkable -class Batcher(Protocol): - @abstractmethod - def __call__( - self, - data: PyTree[Any], - batch_size: int, - batch_axis: int, - *, - key: PRNGKeyArray, - ) -> Generator[PyTree[Any], TrainingState, None]: - pass - - -def stateful_batch_data( - data: PyTree[Any], - batch_size: int, - batch_axis: int, - convert_to_numpy: bool = True, - *, - key: PRNGKeyArray, # Only cor compliance with the `Batcher` protocol -) -> Generator[PyTree[Any], TrainingContext, None]: - """Create a stateful batch generator that uses the step as seed.""" - batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis) - - # Convert to Numpy arrays. Numpy's slicing is much faster than JAX's, so - # for fast model training steps this actually makes a huge difference! - # However, be aware that this is likely only true if JAX runs on CPU. - if convert_to_numpy: - data = jax.tree.map( - lambda x, a: x if a is None else np.array(x), - data, - batch_axis, - is_leaf=lambda x: x is None, - ) - - # Reduce batch size if the dataset has less examples than batch size - batch_size = min(batch_size, dataset_size) - - indices = jnp.arange(dataset_size) - while True: - # Store the training state as received by the `.send(state)` within - # the training loop. - ctx: TrainingContext = yield - key = jax.random.PRNGKey(ctx.state.step) # Create key from step - perm = jr.permutation(key, indices) - (key,) = jr.split(key, 1) # Update key - start, end = 0, batch_size - while end <= dataset_size: - batch_perm = perm[start:end] - yield jax.tree.map( - lambda a, x: x if a is None else x[batch_perm], - batch_axis, - data, - is_leaf=lambda x: x is None, - ) - start = end - end = start + batch_size - - -def fit_core[T: eqx.Module]( - updater: Updater, - batcher: Generator[PyTree[Any], TrainingContext, None], - ctx: TrainingContext, - steps: int, - callbacks: Iterable[Callback] | None = None, -): - @eqx.filter_jit - def make_step(batch, flat_model, flat_opt_state): - # Use the unflatten trick to speed up training, - # see https://docs.kidger.site/equinox/tricks/ - model = jax.tree_util.tree_unflatten( - ctx.state.treedef_model, flat_model - ) - opt_state = jax.tree_util.tree_unflatten( - ctx.state.treedef_opt_state, flat_opt_state - ) - - # Compute and apply the parameter updates - # params, static = eqx.partition(model, eqx.is_inexact_array) - # params = updater(params, static, batch, opt_state) - model, opt_state = updater(model, batch, opt_state) - - # Apply the Constraint in the model to ensure apply-constrains are met - # after the update. - model = apply(model) - - flat_model = jax.tree_util.tree_leaves(model) - flat_opt_state = jax.tree_util.tree_leaves(opt_state) - - return flat_model, flat_opt_state - - # Make callbacks iterable - callbacks = [] if callbacks is None else list(callbacks) - - for callback in callbacks: - callback.on_training_start(ctx) - - for _ in range(steps): - next(batcher) - batch = batcher.send(ctx) # Send the context - flat_model, flat_opt_state = make_step( - batch, ctx.state.flat_model, ctx.state.flat_opt_state - ) - - ctx.state.update(flat_model, flat_opt_state) - - # Run all callbacks and break if any of them request termination of - # the training loop. - # Note! The square brackets are important. Otherwise the loop is - # terminated with the first callback that returns true. But we want - # to run all callbacks first and then decide, whether to terminate. - if any([callback(ctx) for callback in callbacks]): - break - - # Call callbacks after training - for callback in callbacks: - callback.on_training_end(ctx) - - return ctx - - -def fit( - model, - data, - *, - batch_size: int = 32, - batch_axis: PyTree[int | None] = 0, - validation_data: PyTree[Any] = None, - steps: int = 1_000, - loss: LossFactory = mse, - optimizer: optax.GradientTransformation, - init_opt_state: PyTree[Any] = None, - batcher: Batcher = stateful_batch_data, - updater: UpdaterFactory = optax_transform_update_fn_updater, - key: PRNGKeyArray, -): - value_fn, value_and_grad_fn = loss(batch_axis) - evaluator = EvaluationContext(value_fn, data, val_data=validation_data) - state = TrainingState( - model=model, - opt_state=optimizer.init(eqx.filter(model, eqx.is_inexact_array)) - if init_opt_state is None - else init_opt_state - if init_opt_state is None - else init_opt_state, - ) - ctx = TrainingContext( - state=state, - evaluator=evaluator, - ) - - ctx = fit_core( - updater(optimizer.update, value_fn, value_and_grad_fn), - batcher( - data=data, batch_axis=batch_axis, batch_size=batch_size, key=key - ), - ctx, - steps, - ) - return ctx.state.model - - -if __name__ == "__main__": - # Test fit - x = jnp.linspace(0.0, 1.0, 2).reshape(-1, 1) - y = 2.0 * x + 1.0 - model = eqx.nn.Linear(1, 1, key=eqx.internal.GetKey()()) - model = fit( - model, (x, y), optimizer=optax.adam(1.0), key=eqx.internal.GetKey()() - ) - y_pred = jax.vmap(model)(x) - assert jnp.allclose(y_pred, y) - - # Test fit_core - x = jnp.linspace(0.0, 1.0, 2).reshape(-1, 1) - y = 2.0 * x + 1.0 - data = (x, y) - model = eqx.nn.Linear(1, 1, key=eqx.internal.GetKey()()) - batch_axis = 0 - optimizer = optax.adam(1.0) - value_fn, value_and_grad_fn = mse(batch_axis) - evaluator = EvaluationContext(value_fn, (x, y)) - state = TrainingState( - model=model, - opt_state=optimizer.init(eqx.filter(model, eqx.is_inexact_array)), - ) - ctx = TrainingContext( - state=state, - evaluator=evaluator, - ) - batcher = stateful_batch_data( - (x, y), - batch_size=32, - batch_axis=batch_axis, - key=eqx.internal.GetKey()(), # Unused - ) - updater = optax_transform_update_fn_updater( - optimizer.update, value_fn, value_and_grad_fn - ) - state = fit_core(updater, batcher, ctx, steps=1000) - y_pred = jax.vmap(state.model)(x) - assert jnp.allclose(y_pred, y) - - import pprint - - pprint.pp(state) diff --git a/klax/_training.py b/klax/_training.py index 167348d..3cd9d95 100644 --- a/klax/_training.py +++ b/klax/_training.py @@ -12,182 +12,56 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Implements a basic training loop.""" +"""Implements a training loop.""" -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, overload +from collections.abc import Callable, Generator, Iterable +from typing import Any, Protocol, Self import equinox as eqx import jax +import jax.numpy as jnp +import numpy as np import optax -from jaxtyping import PRNGKeyArray, PyTree - -from ._callbacks import ( - Callback, - CallbackArgs, - HistoryCallback, +from jaxtyping import PRNGKeyArray, PyTree, Scalar +from optax._src.utils import Sequence + +from ._callbacks import Callback, HistoryCallback +from ._context import ( + EvaluationContext, + TimingInfo, + TrainingContext, + TrainingState, ) -from ._datahandler import ( - BatchGenerator, - batch_data, - broadcast_and_get_size, +from ._datahandler import Batcher, BatchGenerator, batch_data +from ._losses import LossFactory, mse +from ._updaters import ( + Updater, + UpdaterFactory, + optax_transform_update_fn_updater, ) -from ._losses import Loss, mse -from ._wrappers import apply, unwrap - - -@overload -def fit[T: eqx.Module]( - model: T, - data: PyTree[Any], - *, - batch_size: int = 32, - batch_axis: PyTree[int | None] = 0, - validation_data: PyTree[Any] = None, - steps: int = 1000, - loss_fn: Loss = mse, - optimizer: optax.GradientTransformation = optax.adam(1e-3), - init_opt_state: PyTree[Any] = None, - batcher: BatchGenerator = batch_data, - history: None = None, - callbacks: Iterable[Callback] | None = None, - key: PRNGKeyArray, -) -> tuple[T, HistoryCallback]: ... -@overload -def fit[T: eqx.Module, H: Callback]( - model: T, - data: PyTree[Any], - *, - batch_size: int = 32, - batch_axis: PyTree[int | None] = 0, - validation_data: PyTree[Any] = None, - steps: int = 1000, - loss_fn: Loss = mse, - optimizer: optax.GradientTransformation = optax.adam(1e-3), - init_opt_state: PyTree[Any] = None, - batcher: BatchGenerator = batch_data, - history: H, - callbacks: Iterable[Callback] | None = None, - key: PRNGKeyArray, -) -> tuple[T, H]: ... -def fit[T: eqx.Module, H: Callback]( - model: T, - data: PyTree[Any], - *, - batch_size: int = 32, - batch_axis: PyTree[int | None] = 0, - validation_data: PyTree[Any] = None, - steps: int = 1000, - loss_fn: Loss = mse, - optimizer: optax.GradientTransformation = optax.adam(1e-3), - init_opt_state: PyTree[Any] = None, - batcher: BatchGenerator = batch_data, - history: HistoryCallback | H | None = None, - callbacks: Iterable[Callback] | None = None, - key: PRNGKeyArray, -) -> tuple[T, HistoryCallback | H]: - """Trains a model using an optimizer from optax. - - Args: - model: The model instance, which should be trained. It must be a - subclass of `equinox.Module`. The model may contain - [`klax.Unwrappable`][] wrappers. - data: The training data can be any `PyTree` with `ArrayLike` leaves. - Most likely you'll want `data` to be a tuple `(x, y)` with model - inputs `x` and model outputs `y`. - batch_size: The number of examples in a batch. - batch_axis: A `PyTree` denoting, which axis is the batch axis for - arrays in `data`. `batch_axis` must be a prefix of `data`. By - specifying `batch_axis` as a `PyTree` it is possible to specify - different batch axes for different leaves of `data`. (Defaults to - `0`, meaning the first axes of arrays in `data` are batch - dimensions.) - validation_data: Arbitrary `PyTree` used for validation during - training. Must have the same tree structure as `data`. (Defaults - to None.) - steps: Number of gradient updates to apply. (Defaults to 1000.) - loss_fn: The loss function with call signature - `(model: PyTree, data: PyTree, batch_axis: int | None | - Sequence[Any]) -> float`. (Defaults to `mse`.) - optimizer: The optimizer. Any optax gradient transform to calculate - the updates for the model. (Defaults to optax.adam(1e-3).) - init_opt_state: The initial state of the optimizer. If `None`, the - optimizer is initialized from scratch. By providing a value for - `init_opt_state`, the user can resume training from a previous - state (e.g., obtained from the `HistoryCallback.last_opt_state`). - (Defaults to `None`.) - batcher: The data loader that splits inputs and targets into batches. - (Defaults to `batch_data`.) - history: A callback intended for tracking the training process. If no - custom callback is passed the [`klax.HistoryCallback`][] with a - logging interval of 100 steps is used. To change the logging - increment or verbosity of this default callback, pass a - `HistoryCallback` object to this argument, e.g., - `history=HistoryCallback(log_every=10, verbose=False)` for logging - on every 10-th step without printing the loss. - callbacks: Callback functions that are evaluated after every training - step. They can be used to implement early stopping, custom history - logging and more. The argument to the callback function is a - CallbackArgs object. (Defaults to `None`. Keyword only Argument) - key: A `jax.random.PRNGKey` used to provide randomness for batch - generation. (Keyword only argument.) +from ._wrappers import apply - Note: - This function assumes that the batch dimension is always oriented along - the first axes of any `jax.Array` - - Returns: - A tuple of the trained model and the loss history. - - """ - # Braodcast the batch_axis to the data. While this happens again in the - # batch_data, doing it here allows the use of the broadcasted batch_axis in - # the loss function. If `batch_axis` is a prefix of `data`, this ensures - # that only leafs of type ArrayLike are vmapped. Thus it is possible to - # have data like `(str, array)` ans still use `batch_axis=0` instead of - # `batch_axis=(None, 0)`. - batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis) - - # Define a function to calculate the loss. This is jit compiled to speed up - # the loss evaluation for the loss history. - @eqx.filter_jit - def combined_loss(model, batch): - model = unwrap(model) - return loss_fn(model, batch, batch_axis=batch_axis) - - # This partitioned loss function is required within the make_step function, - # because the optax.lbgfs GradientTransformation required the loss function - # to be diretly dependent on the parameters. - def partitioned_loss(params, static, batch): - model = eqx.combine(params, static) - return combined_loss(model, batch) +def fit_core( + updater: Updater, + batcher: BatchGenerator, + ctx: TrainingContext, + steps: int, + callbacks: Iterable[Callback] | Callback | None = None, +) -> tuple[TrainingContext, list[Callback]]: @eqx.filter_jit - def make_step(batch, flat_model, optimizer, flat_opt_state): + def make_step(batch, flat_model, flat_opt_state): # Use the unflatten trick to speed up training, # see https://docs.kidger.site/equinox/tricks/ - model = jax.tree_util.tree_unflatten(treedef_model, flat_model) + model = jax.tree_util.tree_unflatten(ctx.treedef_model, flat_model) opt_state = jax.tree_util.tree_unflatten( - treedef_opt_state, flat_opt_state + ctx.treedef_opt_state, flat_opt_state ) # Compute and apply the parameter updates - params, static = eqx.partition(model, eqx.is_inexact_array) - value, grad = jax.value_and_grad(partitioned_loss)( - params, static, batch - ) - updates, opt_state = optimizer.update( - grad, - opt_state, - params, - value=value, - grad=grad, - value_fn=jax.tree_util.Partial( - partitioned_loss, static=static, batch=batch - ), - ) - params = optax.apply_updates(params, updates) - model = eqx.combine(params, static) + # params, static = eqx.partition(model, eqx.is_inexact_array) + # params = updater(params, static, batch, opt_state) + model, opt_state = updater(model, batch, opt_state) # Apply the Constraint in the model to ensure apply-constrains are met # after the update. @@ -198,65 +72,129 @@ def make_step(batch, flat_model, optimizer, flat_opt_state): return flat_model, flat_opt_state - if init_opt_state is None: - # Initialize the optimizer and 'tell it' to optimize with respect to - # all inexact arrays in the model. This is done by passing the model to - # the optimizer. - opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array)) - else: - opt_state = init_opt_state - - # Apply the Constraint in the model to ensure apply-constrains are met - # initially - model = apply(model) - - # Use the unflatten trick to speed up training, - # see https://docs.kidger.site/equinox/tricks/ - flat_model, treedef_model = jax.tree.flatten(model) - flat_opt_state, treedef_opt_state = jax.tree.flatten(opt_state) - # Make callbacks iterable - callbacks = [] if callbacks is None else list(callbacks) - - # Initialize callback arguments and history - if history is None: - history = HistoryCallback(log_every=100) - callbacks.append(history) - - cbargs = CallbackArgs( - combined_loss, treedef_model, treedef_opt_state, data, validation_data - ) + if callbacks is None: + callbacks = [] + elif isinstance(callbacks, Callback): + callbacks = [callbacks] + else: + callback = list(callbacks) - # Call callbacks after training - cbargs.update(flat_model, flat_opt_state, 0) + # Initialize context and callbacks + ctx.update(ctx.flat_model, ctx.flat_opt_state) for callback in callbacks: - callback.on_training_start(cbargs) + callback.on_training_start(ctx) + + for _ in range(steps): + next(batcher) # Prime the batcher + batch = batcher.send(ctx) # Send the context - # Loop over all training steps - for step, batch in zip( - range(1, steps + 1), - batcher(data, batch_size, batch_axis, key=key), - ): flat_model, flat_opt_state = make_step( - batch, flat_model, optimizer, flat_opt_state + batch, ctx.flat_model, ctx.flat_opt_state ) - # Update callbacks arguments with the current state of the model - cbargs.update(flat_model, flat_opt_state, step) + ctx.update(flat_model, flat_opt_state) # Run all callbacks and break if any of them request termination of # the training loop. # Note! The square brackets are important. Otherwise the loop is # terminated with the first callback that returns true. But we want # to run all callbacks first and then decide, whether to terminate. - if any([callback(cbargs) for callback in callbacks]): + if any([callback(ctx) for callback in callbacks]): break - model = jax.tree_util.tree_unflatten(treedef_model, flat_model) - # Call callbacks after training - cbargs.update(flat_model, flat_opt_state, -1) + ctx.update(ctx.flat_model, ctx.flat_opt_state) for callback in callbacks: - callback.on_training_end(cbargs) + callback.on_training_end(ctx) + + return ctx, list(callbacks) + + +def fit[T: eqx.Module]( + model: T, + data, + *, + batch_size: int = 32, + batch_axis: PyTree[int | None] = 0, + validation_data: PyTree[Any] = None, + steps: int = 1_000, + loss: LossFactory = mse, + optimizer: optax.GradientTransformation, + init_opt_state: PyTree[Any] = None, + batcher: Batcher = batch_data, + updater: UpdaterFactory = optax_transform_update_fn_updater, + callbacks: Iterable[Callback] | Callback | None = HistoryCallback(), + key: PRNGKeyArray, +) -> tuple[T, list[Callback]]: + value_fn, value_and_grad_fn = loss(batch_axis) + evaluator = EvaluationContext(value_fn, data, val_data=validation_data) + state = TrainingState( + model=model, + opt_state=optimizer.init(eqx.filter(model, eqx.is_inexact_array)) + if init_opt_state is None + else init_opt_state + if init_opt_state is None + else init_opt_state, + ) + ctx = TrainingContext( + state=state, + evaluator=evaluator, + timing=TimingInfo(), + ) + + ctx, callbacks = fit_core( + updater(optimizer.update, value_fn, value_and_grad_fn), + batcher( + data=data, batch_axis=batch_axis, batch_size=batch_size, key=key + ), + ctx, + steps, + callbacks=callbacks, + ) + return ctx.model, callbacks + + +if __name__ == "__main__": + # Test fit + x = jnp.linspace(0.0, 1.0, 2).reshape(-1, 1) + y = 2.0 * x + 1.0 + model = eqx.nn.Linear(1, 1, key=eqx.internal.GetKey()()) + model, _ = fit( + model, (x, y), optimizer=optax.adam(1.0), key=eqx.internal.GetKey()() + ) + y_pred = jax.vmap(model)(x) + assert jnp.allclose(y_pred, y) + + # Test fit_core + x = jnp.linspace(0.0, 1.0, 2).reshape(-1, 1) + y = 2.0 * x + 1.0 + data = (x, y) + model = eqx.nn.Linear(1, 1, key=eqx.internal.GetKey()()) + batch_axis = 0 + optimizer = optax.adam(1.0) + value_fn, value_and_grad_fn = mse(batch_axis) + evaluator = EvaluationContext(value_fn, (x, y)) + state = TrainingState( + model=model, + opt_state=optimizer.init(eqx.filter(model, eqx.is_inexact_array)), + ) + ctx = TrainingContext( + state=state, evaluator=evaluator, timing=TimingInfo() + ) + batcher = batch_data( + (x, y), + batch_size=32, + batch_axis=batch_axis, + key=eqx.internal.GetKey()(), # Unused + ) + updater = optax_transform_update_fn_updater( + optimizer.update, value_fn, value_and_grad_fn + ) + ctx, _ = fit_core(updater, batcher, ctx, steps=1000) + y_pred = jax.vmap(ctx.model)(x) + assert jnp.allclose(y_pred, y) + + import pprint - return model, history + pprint.pp(state) diff --git a/klax/_updaters.py b/klax/_updaters.py new file mode 100644 index 0000000..475fc41 --- /dev/null +++ b/klax/_updaters.py @@ -0,0 +1,67 @@ +from abc import abstractmethod +from typing import Protocol + +import equinox as eqx +import jax +import optax +from jaxtyping import PyTree + +from ._losses import ValueAndGradFn, ValueFn + + +class Updater(Protocol): + @abstractmethod + def __call__( + self, model: PyTree, batch: PyTree, opt_state: PyTree + ) -> tuple[PyTree, PyTree]: + pass + + +class UpdaterFactory(Protocol): + @abstractmethod + def __call__( + self, + opt_update: optax.TransformUpdateFn | optax.TransformUpdateExtraArgsFn, + value_fn: ValueFn, + value_and_grad_fn: ValueAndGradFn, + ) -> Updater: + pass + + +def optax_transform_update_fn_updater( + opt_update: optax.TransformUpdateFn, + value_fn: ValueFn, + value_and_grad_fn: ValueAndGradFn, +) -> Updater: + def wrapper(model, batch, opt_state): + _, grad = value_and_grad_fn(model, batch) + updates, opt_state = opt_update( + grad, + opt_state, + eqx.filter(model, eqx.is_inexact_array), + ) + model = eqx.apply_updates(model, updates) + return model, opt_state + + return wrapper + + +def optax_transform_update_fn_extra_args_updater( + opt_update: optax.TransformUpdateExtraArgsFn, + value_fn: ValueFn, + value_and_grad_fn: ValueAndGradFn, +) -> Updater: + def wrapper(model, batch, opt_state): + value, grad = value_and_grad_fn(model, batch) + updates, opt_state = opt_update( + grad, + opt_state, + eqx.filter(model, eqx.is_inexact_array), + value=value, + grad=grad, + value_fn=jax.tree_util.Partial(value_fn, model=model, batch=batch), + ) + model = eqx.apply_updates(model, updates) + return model, opt_state + + return wrapper From 4c2b93096e4bb695cc8d276169cef54d37401d10 Mon Sep 17 00:00:00 2001 From: jaosch Date: Sat, 13 Sep 2025 16:59:26 +0200 Subject: [PATCH 9/9] Set batch_data back to its original functionality The dependency on the training step, which was introduced as a test has been removed. --- klax/_datahandler.py | 116 ------------------------------------------- 1 file changed, 116 deletions(-) diff --git a/klax/_datahandler.py b/klax/_datahandler.py index 05b0514..f446ec9 100644 --- a/klax/_datahandler.py +++ b/klax/_datahandler.py @@ -126,7 +126,6 @@ def batch_data( # Store the training state as received by the `.send(state)` within # the training loop. ctx: TrainingContext = yield - key = jax.random.PRNGKey(ctx.step) # Create key from step perm = jr.permutation(key, indices) (key,) = jr.split(key, 1) # Update key start, end = 0, batch_size @@ -142,121 +141,6 @@ def batch_data( end = start + batch_size -# @typing.runtime_checkable -# class BatchGenerator(Protocol): -# def __call__( -# self, -# data: PyTree[Any], -# batch_size: int, -# batch_axis: PyTree[int | None], -# *, -# key: PRNGKeyArray, -# ) -> Generator[PyTree[Any], None, None]: -# raise NotImplementedError - - -# def batch_data( -# data: PyTree[Any], -# batch_size: int = 32, -# batch_axis: PyTree[int | None] = 0, -# convert_to_numpy: bool = True, -# *, -# key: PRNGKeyArray, -# ) -> Generator[PyTree[Any], None, None]: -# """Create a `Generator` that draws subsets of data without replacement. - -# The data can be any `PyTree` with `ArrayLike` leaves. If `batch_axis` is -# passed, batch axes (including `None` for no batching) can be specified for -# every leaf individualy. -# A generator is returned that indefinetly yields batches of data with size -# `batch_size`. Examples are drawn without replacement until the remaining -# dataset is smaller than `batch_size`, at which point the dataset will be -# reshuffeld and the process starts over. - -# Example: -# This is an example for a nested `PyTree`, where the elements x and y -# have batch dimension along the first axis. - -# ```python -# >>> import klax -# >>> import jax -# >>> import jax.numpy as jnp -# >>> -# >>> x = jnp.array([1., 2.]) -# >>> y = jnp.array([[1.], [2.]]) -# >>> data = (x, {"a": 1.0, "b": y}) -# >>> batch_axis = (0, {"a": None, "b": 0}) -# >>> iter_data = klax.batch_data( -# ... data, -# ... 32, -# ... batch_axis, -# ... key=jax.random.key(0) -# ... ) -# >>> -# ``` - -# Args: -# data: The data that shall be batched. It can be any `PyTree` with -# `ArrayLike` leaves. -# batch_size: The number of examples in a batch. -# batch_axis: PyTree of the batch axis indices. `None` is used to -# indicate that the corresponding leaf or subtree in data does not -# have a batch axis. `batch_axis` must have the same structure as -# `data` or have `data` as a prefix. -# (Defaults to 0, meaning all leaves in `data` are -# batched along their first dimension.) -# convert_to_numpy: If `True`, batched data leafs will be converted to -# Numpy arrays before batching. This is useful for performance -# reasons, as Numpy's slicing is much faster than JAX's. -# key: A `jax.random.PRNGKey` used to provide randomness for batch -# generation. (Keyword only argument.) - -# Returns: -# A `Generator` that yields a random batch of data every time is is -# called. - -# Yields: -# A `PyTree[ArrayLike]` with the same structure as `data`. Where all -# batched leaves have `batch_size`. - -# Note: -# Note that if the size of the dataset is smaller than `batch_size`, the -# obtained batches will have dataset size. - -# """ -# batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis) - -# # Convert to Numpy arrays. Numpy's slicing is much faster than JAX's, so -# # for fast model training steps this actually makes a huge difference! -# # However, be aware that this is likely only true if JAX runs on CPU. -# if convert_to_numpy: -# data = jax.tree.map( -# lambda x, a: x if a is None else np.array(x), -# data, -# batch_axis, -# is_leaf=lambda x: x is None, -# ) - -# # Reduce batch size if the dataset has less examples than batch size -# batch_size = min(batch_size, dataset_size) - -# indices = jnp.arange(dataset_size) -# while True: -# perm = jr.permutation(key, indices) -# (key,) = jr.split(key, 1) # Update key -# start, end = 0, batch_size -# while end <= dataset_size: -# batch_perm = perm[start:end] -# yield jax.tree.map( -# lambda a, x: x if a is None else x[batch_perm], -# batch_axis, -# data, -# is_leaf=lambda x: x is None, -# ) -# start = end -# end = start + batch_size - - def split_data( data: PyTree[Any], proportions: Sequence[int | float],