From 265c3f20faf0629b68e049a04c4207a9831096c7 Mon Sep 17 00:00:00 2001 From: jaosch Date: Fri, 17 Oct 2025 10:15:22 +0200 Subject: [PATCH 1/6] Changed model type for `fit` to PyTree This is more flexible, because in general the model can be any PyTree, e.g., a simple function, and not just an `equinox.Module`. --- klax/_training.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/klax/_training.py b/klax/_training.py index 167348d..aa8fab0 100644 --- a/klax/_training.py +++ b/klax/_training.py @@ -37,7 +37,7 @@ @overload -def fit[T: eqx.Module]( +def fit[T: PyTree[Any]]( model: T, data: PyTree[Any], *, @@ -54,7 +54,7 @@ def fit[T: eqx.Module]( key: PRNGKeyArray, ) -> tuple[T, HistoryCallback]: ... @overload -def fit[T: eqx.Module, H: Callback]( +def fit[T: PyTree[Any], H: Callback]( model: T, data: PyTree[Any], *, @@ -70,7 +70,7 @@ def fit[T: eqx.Module, H: Callback]( callbacks: Iterable[Callback] | None = None, key: PRNGKeyArray, ) -> tuple[T, H]: ... -def fit[T: eqx.Module, H: Callback]( +def fit[T: PyTree[Any], H: Callback]( model: T, data: PyTree[Any], *, From 86d2e22a9385de8c1c9931fd6ea0509bf13a4b7d Mon Sep 17 00:00:00 2001 From: jaosch Date: Fri, 17 Oct 2025 10:19:46 +0200 Subject: [PATCH 2/6] Added new `optimistic_adam_v2` to testing suite. This will soon replace the old `optimistic_adam` optimizer. --- tests/test_training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_training.py b/tests/test_training.py index 154226b..0285307 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -132,6 +132,7 @@ def __call__(self, cbargs: klax.CallbackArgs): optax.novograd(1.0), optax.optimistic_gradient_descent(1.0), optax.optimistic_adam(1.0), + optax.optimistic_adam_v2(1.0), optax.polyak_sgd(1.0), optax.radam(1.0), optax.rmsprop(1.0), From bbebf59c0059c8887c87c04a7d2585148647c86e Mon Sep 17 00:00:00 2001 From: jaosch Date: Fri, 17 Oct 2025 10:26:46 +0200 Subject: [PATCH 3/6] Added `PyTree` output type to `apply` and `finalize` We will see in practice how good this works. --- klax/_wrappers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/klax/_wrappers.py b/klax/_wrappers.py index 210e7cf..6740eec 100644 --- a/klax/_wrappers.py +++ b/klax/_wrappers.py @@ -369,7 +369,7 @@ def apply(self) -> Self: pass -def apply(tree: PyTree): +def apply(tree: PyTree) -> PyTree: """Map across a PyTree and apply all [Constraints][klax.Constraint]. This leaves all other nodes unchanged. @@ -437,7 +437,7 @@ def apply(self) -> Self: # ===----------------------------------------------------------------------===# -def finalize(tree: PyTree): +def finalize(tree: PyTree) -> PyTree: """Make a model containing [Constraints][klax.Constraint] callable. This function combined that functionalities of [`klax.apply`][] and From 856521f6cbe5a53fe8c40e63c6334382c09e260b Mon Sep 17 00:00:00 2001 From: jaosch Date: Fri, 17 Oct 2025 10:47:22 +0200 Subject: [PATCH 4/6] Consistent renaming of `batch_axis` to `batch_axes`. 
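For illustration, a minimal usage sketch with the renamed keyword, written against the
`fit`/`batch_data` signatures changed in this series and the top-level `klax.fit` call
used in the tests. The `eqx.nn.MLP` model and the array shapes are hypothetical
placeholders, not part of this change:

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import klax

key, model_key = jr.split(jr.PRNGKey(0))
x = jnp.linspace(0.0, 1.0, 100)[:, None]  # 100 examples, batched along axis 0
y = 2.0 * x + 1.0

# Per PATCH 1 the model may be any PyTree (e.g. a plain function);
# an Equinox module is just one convenient choice for this sketch.
model = eqx.nn.MLP(in_size=1, out_size=1, width_size=16, depth=2, key=model_key)

# A single integer broadcasts to every leaf of `data`.
model, history = klax.fit(model, (x, y), batch_axes=0, key=key)

# `batch_axes` may also be a PyTree prefix of `data`, e.g. `(0, (None, 0))`
# when one leaf carries no batch dimension (cf. test_batch_data below).
```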
--- klax/_datahandler.py | 48 +++++++++++++++++++-------------------- klax/_losses.py | 28 +++++++++++------------ klax/_training.py | 30 ++++++++++++------------ tests/test_datahandler.py | 12 +++++----- tests/test_training.py | 2 +- 5 files changed, 60 insertions(+), 60 deletions(-) diff --git a/klax/_datahandler.py b/klax/_datahandler.py index eb300ce..93d4110 100644 --- a/klax/_datahandler.py +++ b/klax/_datahandler.py @@ -28,40 +28,40 @@ def broadcast_and_get_size( - data: PyTree[Any], batch_axis: PyTree[int | None] + data: PyTree[Any], batch_axes: PyTree[int | None] ) -> tuple[PyTree[int | None], int]: - """Broadcast `batch_axis` to the same structure as `data` and get size. + """Broadcast `batch_axes` to the same structure as `data` and get size. Args: data: PyTree of data. - batch_axis: PyTree of the batch axis indices. `None` is used to + batch_axes: PyTree of batch axis indices. `None` is used to indicate that the corresponding leaf or subtree in data does not - have a batch axis. `batch_axis` must have the same structure as + have a batch axis. `batch_axes` must have the same structure as `data` or have `data` as a prefix. Raises: - ValueError: If `batch_axis` is not a prefix of `data`. + ValueError: If `batch_axes` is not a prefix of `data`. ValueError: If not all batch axes have the same size. Returns: - Tuple of the broadcasted `batch_axis` and the `dataset_size`. + Tuple of the broadcasted `batch_axes` and the `dataset_size`. """ try: - batch_axis = jax.tree.map( + batch_axes = jax.tree.map( lambda a, d: jax.tree.map(eqx.if_array(a), d), - batch_axis, + batch_axes, data, is_leaf=lambda x: x is None, ) except ValueError as e: raise ValueError( - f"batch_axis must be a prefix of data. Original message: {e}" + f"batch_axes must be a prefix of data. Original message: {e}" ) dataset_sizes = jax.tree.map( lambda a, d: None if a is None else d.shape[a], - batch_axis, + batch_axes, data, is_leaf=lambda x: x is None, ) @@ -74,7 +74,7 @@ def broadcast_and_get_size( raise ValueError("All batched arrays must have equal batch sizes.") dataset_size = dataset_sizes[0] - return batch_axis, dataset_size + return batch_axes, dataset_size @typing.runtime_checkable @@ -83,7 +83,7 @@ def __call__( self, data: PyTree[Any], batch_size: int, - batch_axis: PyTree[int | None], + batch_axes: PyTree[int | None], *, key: PRNGKeyArray, ) -> Generator[PyTree[Any], None, None]: @@ -93,7 +93,7 @@ def __call__( def batch_data( data: PyTree[Any], batch_size: int = 32, - batch_axis: PyTree[int | None] = 0, + batch_axes: PyTree[int | None] = 0, convert_to_numpy: bool = True, *, key: PRNGKeyArray, @@ -134,9 +134,9 @@ def batch_data( data: The data that shall be batched. It can be any `PyTree` with `ArrayLike` leaves. batch_size: The number of examples in a batch. - batch_axis: PyTree of the batch axis indices. `None` is used to + batch_axes: PyTree of batch axis indices. `None` is used to indicate that the corresponding leaf or subtree in data does not - have a batch axis. `batch_axis` must have the same structure as + have a batch axis. `batch_axes` must have the same structure as `data` or have `data` as a prefix. (Defaults to 0, meaning all leaves in `data` are batched along their first dimension.) @@ -159,7 +159,7 @@ def batch_data( obtained batches will have dataset size. """ - batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis) + batch_axes, dataset_size = broadcast_and_get_size(data, batch_axes) # Convert to Numpy arrays. 
Numpy's slicing is much faster than JAX's, so # for fast model training steps this actually makes a huge difference! @@ -168,7 +168,7 @@ def batch_data( data = jax.tree.map( lambda x, a: x if a is None else np.array(x), data, - batch_axis, + batch_axes, is_leaf=lambda x: x is None, ) @@ -184,7 +184,7 @@ def batch_data( batch_perm = perm[start:end] yield jax.tree.map( lambda a, x: x if a is None else x[batch_perm], - batch_axis, + batch_axes, data, is_leaf=lambda x: x is None, ) @@ -195,14 +195,14 @@ def batch_data( def split_data( data: PyTree[Any], proportions: Sequence[int | float], - batch_axis: PyTree[int | None] = 0, + batch_axes: PyTree[int | None] = 0, *, key: PRNGKeyArray, ) -> tuple[PyTree[Any], ...]: """Split a `PyTree` of data into multiply randomly drawn subsets. This function is useful for splitting into training and test datasets. The - axis of the split if controlled by the `batch_axis` argument, which + axis of the split if controlled by the `batch_axes` argument, which specifies the batch axis for each leaf in `data`. Example: @@ -230,9 +230,9 @@ def split_data( proportions: Proportions of the split that will be applied to the data, e.g., `(80, 20)` for a 80% to 20% split. The proportions must be non-negative. - batch_axis: PyTree of the batch axis indices. `None` is used to + batch_axes: PyTree of batch axis indices. `None` is used to indicate that the corresponding leaf or subtree in data does not - have a batch axis. `batch_axis` must have the same structure as + have a batch axis. `batch_axes` must have the same structure as `data` or have `data` as a prefix. (Defaults to 0) key: A `jax.random.PRNGKey` used to provide randomness to the split. (Keyword only argument.) @@ -248,7 +248,7 @@ def split_data( raise ValueError("Proportions must be non-negative.") props = props / jnp.sum(props) - batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis) + batch_axes, dataset_size = broadcast_and_get_size(data, batch_axes) indices = jnp.arange(dataset_size) perm = jr.permutation(key, indices) @@ -266,7 +266,7 @@ def get_subset(section): lambda a, d: d if a is None else jnp.take(d, section, axis=a, unique_indices=True), - batch_axis, + batch_axes, data, is_leaf=lambda x: x is None, ) diff --git a/klax/_losses.py b/klax/_losses.py index 5e5b7b1..906025f 100644 --- a/klax/_losses.py +++ b/klax/_losses.py @@ -34,12 +34,12 @@ class Loss(Protocol): be implemented as follows: ```python - >>> def mse(model, data, batch_axis=0): + >>> def mse(model, data, batch_axes=0): ... x, y = data - ... if isinstance(batch_axis, tuple): - ... in_axes = batch_axis[0] + ... if isinstance(batch_axes, tuple): + ... in_axes = batch_axes[0] ... else: - ... in_axes = batch_axis + ... in_axes = batch_axes ... y_pred = jax.vmap(model, in_axes=(in_axes,))(x) ... return jnp.mean(jnp.square(y_pred - y)) ``` @@ -54,14 +54,14 @@ def __call__( self, model: PyTree, data: PyTree, - batch_axis: int | None | Sequence[Any], + batch_axes: int | None | Sequence[Any], ) -> Scalar: """Abstract method to compute the loss for a given model and data. Args: model: The model parameters or structure to evaluate the loss. data: The input data or structure used for loss computation. - batch_axis: Specifies the axis or axes corresponding to the batch + batch_axes: Specifies the axis or axes corresponding to the batch dimension in the data. Can be an integer, None, or a sequence of values. 
@@ -83,13 +83,13 @@ def __call__( self, model: PyTree, data: PyTree, - batch_axis: int | None | Sequence[Any] = 0, + batch_axes: int | None | Sequence[Any] = 0, ) -> Scalar: x, y = data - if isinstance(batch_axis, tuple): - in_axes = batch_axis[0] + if isinstance(batch_axes, tuple): + in_axes = batch_axes[0] else: - in_axes = batch_axis + in_axes = batch_axes y_pred = jax.vmap(model, in_axes=(in_axes,))(x) return jnp.mean(jnp.square(y_pred - y)) @@ -108,13 +108,13 @@ def __call__( self, model: PyTree, data: PyTree, - batch_axis: int | None | Sequence[Any] = 0, + batch_axes: int | None | Sequence[Any] = 0, ) -> Scalar: x, y = data - if isinstance(batch_axis, tuple): - in_axes = batch_axis[0] + if isinstance(batch_axes, tuple): + in_axes = batch_axes[0] else: - in_axes = batch_axis + in_axes = batch_axes y_pred = jax.vmap(model, in_axes=(in_axes,))(x) return jnp.mean(jnp.abs(y_pred - y)) diff --git a/klax/_training.py b/klax/_training.py index aa8fab0..b6f9c04 100644 --- a/klax/_training.py +++ b/klax/_training.py @@ -42,7 +42,7 @@ def fit[T: PyTree[Any]]( data: PyTree[Any], *, batch_size: int = 32, - batch_axis: PyTree[int | None] = 0, + batch_axes: PyTree[int | None] = 0, validation_data: PyTree[Any] = None, steps: int = 1000, loss_fn: Loss = mse, @@ -59,7 +59,7 @@ def fit[T: PyTree[Any], H: Callback]( data: PyTree[Any], *, batch_size: int = 32, - batch_axis: PyTree[int | None] = 0, + batch_axes: PyTree[int | None] = 0, validation_data: PyTree[Any] = None, steps: int = 1000, loss_fn: Loss = mse, @@ -75,7 +75,7 @@ def fit[T: PyTree[Any], H: Callback]( data: PyTree[Any], *, batch_size: int = 32, - batch_axis: PyTree[int | None] = 0, + batch_axes: PyTree[int | None] = 0, validation_data: PyTree[Any] = None, steps: int = 1000, loss_fn: Loss = mse, @@ -96,9 +96,9 @@ def fit[T: PyTree[Any], H: Callback]( Most likely you'll want `data` to be a tuple `(x, y)` with model inputs `x` and model outputs `y`. batch_size: The number of examples in a batch. - batch_axis: A `PyTree` denoting, which axis is the batch axis for - arrays in `data`. `batch_axis` must be a prefix of `data`. By - specifying `batch_axis` as a `PyTree` it is possible to specify + batch_axes: A `PyTree` denoting, which axis is the batch axis for + arrays in `data`. `batch_axes` must be a prefix of `data`. By + specifying `batch_axes` as a `PyTree` it is possible to specify different batch axes for different leaves of `data`. (Defaults to `0`, meaning the first axes of arrays in `data` are batch dimensions.) @@ -107,7 +107,7 @@ def fit[T: PyTree[Any], H: Callback]( to None.) steps: Number of gradient updates to apply. (Defaults to 1000.) loss_fn: The loss function with call signature - `(model: PyTree, data: PyTree, batch_axis: int | None | + `(model: PyTree, data: PyTree, batch_axes: int | None | Sequence[Any]) -> float`. (Defaults to `mse`.) optimizer: The optimizer. Any optax gradient transform to calculate the updates for the model. (Defaults to optax.adam(1e-3).) @@ -140,20 +140,20 @@ def fit[T: PyTree[Any], H: Callback]( A tuple of the trained model and the loss history. """ - # Braodcast the batch_axis to the data. While this happens again in the - # batch_data, doing it here allows the use of the broadcasted batch_axis in - # the loss function. If `batch_axis` is a prefix of `data`, this ensures + # Braodcast the batch axes to the data. While this happens again in the + # batch_data, doing it here allows the use of the broadcasted batch_axes in + # the loss function. 
If `batch_axes` is a prefix of `data`, this ensures # that only leafs of type ArrayLike are vmapped. Thus it is possible to - # have data like `(str, array)` ans still use `batch_axis=0` instead of - # `batch_axis=(None, 0)`. - batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis) + # have data like `(str, array)` ans still use `batch_axes=0` instead of + # `batch_axes=(None, 0)`. + batch_axes, dataset_size = broadcast_and_get_size(data, batch_axes) # Define a function to calculate the loss. This is jit compiled to speed up # the loss evaluation for the loss history. @eqx.filter_jit def combined_loss(model, batch): model = unwrap(model) - return loss_fn(model, batch, batch_axis=batch_axis) + return loss_fn(model, batch, batch_axes=batch_axes) # This partitioned loss function is required within the make_step function, # because the optax.lbgfs GradientTransformation required the loss function @@ -235,7 +235,7 @@ def make_step(batch, flat_model, optimizer, flat_opt_state): # Loop over all training steps for step, batch in zip( range(1, steps + 1), - batcher(data, batch_size, batch_axis, key=key), + batcher(data, batch_size, batch_axes, key=key), ): flat_model, flat_opt_state = make_step( batch, flat_model, optimizer, flat_opt_state diff --git a/tests/test_datahandler.py b/tests/test_datahandler.py index 26d90e6..d3471dd 100644 --- a/tests/test_datahandler.py +++ b/tests/test_datahandler.py @@ -48,8 +48,8 @@ def test_batch_data(getkey): # Batch mask x = jrandom.uniform(getkey(), (10,)) data = (x, (x, x)) - batch_axis = (0, (None, 0)) - generator = batch_data(data, 2, batch_axis, key=getkey()) + batch_axes = (0, (None, 0)) + generator = batch_data(data, 2, batch_axes, key=getkey()) assert next(generator)[0].shape[0] == 2 assert next(generator)[1][0].shape[0] == 10 assert next(generator)[1][1].shape[0] == 2 @@ -57,8 +57,8 @@ def test_batch_data(getkey): # No batch dimensions x = jrandom.uniform(getkey(), (10,)) data = (x,) - batch_axis = None - generator = batch_data(data, batch_axis=batch_axis, key=getkey()) + batch_axes = None + generator = batch_data(data, batch_axes=batch_axes, key=getkey()) assert next(generator) == data # Different batch sizes @@ -90,8 +90,8 @@ def test_split_data(getkey): ], ) proportions = (2, 1, 1) - batch_axis = (0, 1) - subsets = split_data(data, proportions, batch_axis, key=getkey()) + batch_axes = (0, 1) + subsets = split_data(data, proportions, batch_axes, key=getkey()) for s, p in zip(subsets, (0.5, 0.25, 0.25)): assert s[0].shape == (round(p * batch_size), 2) diff --git a/tests/test_training.py b/tests/test_training.py index 0285307..0b9581e 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -50,7 +50,7 @@ def __call__(self, x): model, _ = klax.fit( model, ((b, x), y), - batch_axis=0, # Test automatic batch axis braodcasting to data + batch_axes=0, # Test automatic batch axis braodcasting to data optimizer=optax.adam(1.0), key=getkey(), ) From 6b097574a6952001977e9efb508e3d3477cad13d Mon Sep 17 00:00:00 2001 From: jaosch Date: Fri, 17 Oct 2025 10:55:39 +0200 Subject: [PATCH 5/6] Renamed `data` to `batch` in `Loss` API --- klax/_losses.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/klax/_losses.py b/klax/_losses.py index 906025f..006dfa2 100644 --- a/klax/_losses.py +++ b/klax/_losses.py @@ -53,14 +53,14 @@ class Loss(Protocol): def __call__( self, model: PyTree, - data: PyTree, + batch: PyTree, batch_axes: int | None | Sequence[Any], ) -> Scalar: """Abstract method to compute the loss for a 
given model and data. Args: model: The model parameters or structure to evaluate the loss. - data: The input data or structure used for loss computation. + batch: A batch of input data used for loss computation. batch_axes: Specifies the axis or axes corresponding to the batch dimension in the data. Can be an integer, None, or a sequence of values. @@ -73,7 +73,7 @@ def __call__( class MSE(Loss): - """Mean squared error for a tuple of data `(x, y)`. + """Mean squared error for a batch of data of the form `(x, y)`. The inputs `x` and the outputs `y` are expected to have the same batch axis and equal length along that axis. @@ -82,10 +82,10 @@ class MSE(Loss): def __call__( self, model: PyTree, - data: PyTree, + batch: PyTree, batch_axes: int | None | Sequence[Any] = 0, ) -> Scalar: - x, y = data + x, y = batch if isinstance(batch_axes, tuple): in_axes = batch_axes[0] else: @@ -98,7 +98,7 @@ def __call__( class MAE(Loss): - """Mean absolute error for a tuple of data `(x, y)`. + """Mean absolute error for a batch of data of the form `(x, y)`. The inputs `x` and the outputs `y` are expected to have the same batch axis and equal length along that axis. @@ -107,10 +107,10 @@ class MAE(Loss): def __call__( self, model: PyTree, - data: PyTree, + batch: PyTree, batch_axes: int | None | Sequence[Any] = 0, ) -> Scalar: - x, y = data + x, y = batch if isinstance(batch_axes, tuple): in_axes = batch_axes[0] else: From c1a827f99ed1764635aab10b4bb243169d0f3006 Mon Sep 17 00:00:00 2001 From: jaosch Date: Fri, 17 Oct 2025 10:56:58 +0200 Subject: [PATCH 6/6] Removed paramax from list of dependencies --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 44c3bff..944e0e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,6 @@ dependencies = [ "equinox>=0.12.2", "jax>=0.6.0", "optax>=0.2.4", - "paramax>=0.0.3", ] classifiers = [ "Development Status :: 3 - Alpha",
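Postscript (not part of the patches above): with the `Loss` protocol from PATCH 5 now
taking `(model, batch, batch_axes)`, a custom loss can be written against it as sketched
below. The weighted-MSE name, the extra weight leaf `w`, and the usage line are purely
illustrative assumptions, mirroring the `MSE`/`MAE` implementations in `klax/_losses.py`:

```python
import jax
import jax.numpy as jnp


def weighted_mse(model, batch, batch_axes=0):
    # Follows the klax `Loss` protocol: (model, batch, batch_axes) -> scalar.
    x, y, w = batch  # `w` is a hypothetical per-example weight leaf
    # As in MSE/MAE: when `batch_axes` is a tuple, use its first entry,
    # i.e. the axis belonging to the model inputs `x`.
    in_axes = batch_axes[0] if isinstance(batch_axes, tuple) else batch_axes
    y_pred = jax.vmap(model, in_axes=(in_axes,))(x)
    return jnp.mean(w * jnp.square(y_pred - y))


# Sketched usage with the `fit` signature from PATCH 4:
#     klax.fit(model, (x, y, w), loss_fn=weighted_mse, key=key)
```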