48 changes: 24 additions & 24 deletions klax/_datahandler.py
@@ -28,40 +28,40 @@


def broadcast_and_get_size(
data: PyTree[Any], batch_axis: PyTree[int | None]
data: PyTree[Any], batch_axes: PyTree[int | None]
) -> tuple[PyTree[int | None], int]:
"""Broadcast `batch_axis` to the same structure as `data` and get size.
"""Broadcast `batch_axes` to the same structure as `data` and get size.

Args:
data: PyTree of data.
batch_axis: PyTree of the batch axis indices. `None` is used to
batch_axes: PyTree of batch axis indices. `None` is used to
indicate that the corresponding leaf or subtree in data does not
have a batch axis. `batch_axis` must have the same structure as
have a batch axis. `batch_axes` must have the same structure as
`data` or have `data` as a prefix.

Raises:
ValueError: If `batch_axis` is not a prefix of `data`.
ValueError: If `batch_axes` is not a prefix of `data`.
ValueError: If not all batch axes have the same size.

Returns:
Tuple of the broadcasted `batch_axis` and the `dataset_size`.
Tuple of the broadcasted `batch_axes` and the `dataset_size`.

"""
try:
batch_axis = jax.tree.map(
batch_axes = jax.tree.map(
lambda a, d: jax.tree.map(eqx.if_array(a), d),
batch_axis,
batch_axes,
data,
is_leaf=lambda x: x is None,
)
except ValueError as e:
raise ValueError(
f"batch_axis must be a prefix of data. Original message: {e}"
f"batch_axes must be a prefix of data. Original message: {e}"
)

dataset_sizes = jax.tree.map(
lambda a, d: None if a is None else d.shape[a],
batch_axis,
batch_axes,
data,
is_leaf=lambda x: x is None,
)
@@ -74,7 +74,7 @@ def broadcast_and_get_size(
raise ValueError("All batched arrays must have equal batch sizes.")
dataset_size = dataset_sizes[0]

return batch_axis, dataset_size
return batch_axes, dataset_size


@typing.runtime_checkable
@@ -83,7 +83,7 @@ def __call__(
self,
data: PyTree[Any],
batch_size: int,
batch_axis: PyTree[int | None],
batch_axes: PyTree[int | None],
*,
key: PRNGKeyArray,
) -> Generator[PyTree[Any], None, None]:
@@ -93,7 +93,7 @@
def batch_data(
data: PyTree[Any],
batch_size: int = 32,
batch_axis: PyTree[int | None] = 0,
batch_axes: PyTree[int | None] = 0,
convert_to_numpy: bool = True,
*,
key: PRNGKeyArray,
@@ -134,9 +134,9 @@ def batch_data(
data: The data that shall be batched. It can be any `PyTree` with
`ArrayLike` leaves.
batch_size: The number of examples in a batch.
batch_axis: PyTree of the batch axis indices. `None` is used to
batch_axes: PyTree of batch axis indices. `None` is used to
indicate that the corresponding leaf or subtree in data does not
have a batch axis. `batch_axis` must have the same structure as
have a batch axis. `batch_axes` must have the same structure as
`data` or have `data` as a prefix.
(Defaults to 0, meaning all leaves in `data` are
batched along their first dimension.)
@@ -159,7 +159,7 @@ def batch_data(
obtained batches will have dataset size.

"""
batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis)
batch_axes, dataset_size = broadcast_and_get_size(data, batch_axes)

# Convert to Numpy arrays. Numpy's slicing is much faster than JAX's, so
# for fast model training steps this actually makes a huge difference!
@@ -168,7 +168,7 @@ def batch_data(
data = jax.tree.map(
lambda x, a: x if a is None else np.array(x),
data,
batch_axis,
batch_axes,
is_leaf=lambda x: x is None,
)

@@ -184,7 +184,7 @@ def batch_data(
batch_perm = perm[start:end]
yield jax.tree.map(
lambda a, x: x if a is None else x[batch_perm],
batch_axis,
batch_axes,
data,
is_leaf=lambda x: x is None,
)
@@ -195,14 +195,14 @@
def split_data(
data: PyTree[Any],
proportions: Sequence[int | float],
batch_axis: PyTree[int | None] = 0,
batch_axes: PyTree[int | None] = 0,
*,
key: PRNGKeyArray,
) -> tuple[PyTree[Any], ...]:
"""Split a `PyTree` of data into multiply randomly drawn subsets.

This function is useful for splitting into training and test datasets. The
axis of the split is controlled by the `batch_axis` argument, which
axis of the split is controlled by the `batch_axes` argument, which
specifies the batch axis for each leaf in `data`.

Example:
@@ -230,9 +230,9 @@ def split_data(
proportions: Proportions of the split that will be applied to the data,
e.g., `(80, 20)` for an 80% to 20% split. The proportions must be
non-negative.
batch_axis: PyTree of the batch axis indices. `None` is used to
batch_axes: PyTree of batch axis indices. `None` is used to
indicate that the corresponding leaf or subtree in data does not
have a batch axis. `batch_axis` must have the same structure as
have a batch axis. `batch_axes` must have the same structure as
`data` or have `data` as a prefix. (Defaults to 0)
key: A `jax.random.PRNGKey` used to provide randomness to the split.
(Keyword only argument.)
@@ -248,7 +248,7 @@
raise ValueError("Proportions must be non-negative.")
props = props / jnp.sum(props)

batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis)
batch_axes, dataset_size = broadcast_and_get_size(data, batch_axes)

indices = jnp.arange(dataset_size)
perm = jr.permutation(key, indices)
@@ -266,7 +266,7 @@ def get_subset(section):
lambda a, d: d
if a is None
else jnp.take(d, section, axis=a, unique_indices=True),
batch_axis,
batch_axes,
data,
is_leaf=lambda x: x is None,
)
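For context, here is a minimal usage sketch of the two public helpers touched by this rename. It assumes `split_data` and `batch_data` are re-exported at the package top level; the array shapes and the 80/20 split are made up for illustration:

```python
import jax.numpy as jnp
import jax.random as jr
import klax

x = jnp.ones((100, 3))  # 100 examples, batched along axis 0
y = jnp.ones((100, 1))

# Split the data into training and test subsets along the batch axes.
(x_train, y_train), (x_test, y_test) = klax.split_data(
    (x, y), proportions=(80, 20), batch_axes=0, key=jr.PRNGKey(0)
)

# Draw one shuffled batch of 32 examples from the training subset.
batches = klax.batch_data(
    (x_train, y_train), batch_size=32, batch_axes=0, key=jr.PRNGKey(1)
)
x_batch, y_batch = next(batches)
```

Passing `batch_axes=(0, None)` instead would mark the second leaf as unbatched, which is the per-leaf flexibility described in the docstrings above.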
44 changes: 22 additions & 22 deletions klax/_losses.py
@@ -34,12 +34,12 @@ class Loss(Protocol):
be implemented as follows:

```python
>>> def mse(model, data, batch_axis=0):
>>> def mse(model, data, batch_axes=0):
... x, y = data
... if isinstance(batch_axis, tuple):
... in_axes = batch_axis[0]
... if isinstance(batch_axes, tuple):
... in_axes = batch_axes[0]
... else:
... in_axes = batch_axis
... in_axes = batch_axes
... y_pred = jax.vmap(model, in_axes=(in_axes,))(x)
... return jnp.mean(jnp.square(y_pred - y))
```
@@ -53,15 +53,15 @@ class Loss(Protocol):
def __call__(
self,
model: PyTree,
data: PyTree,
batch_axis: int | None | Sequence[Any],
batch: PyTree,
batch_axes: int | None | Sequence[Any],
) -> Scalar:
"""Abstract method to compute the loss for a given model and data.

Args:
model: The model parameters or structure to evaluate the loss.
data: The input data or structure used for loss computation.
batch_axis: Specifies the axis or axes corresponding to the batch
batch: A batch of input data used for loss computation.
batch_axes: Specifies the axis or axes corresponding to the batch
dimension in the data. Can be an integer, None, or a sequence
of values.

@@ -73,7 +73,7 @@ def __call__(


class MSE(Loss):
"""Mean squared error for a tuple of data `(x, y)`.
"""Mean squared error for a batch of data of the form `(x, y)`.

The inputs `x` and the outputs `y` are expected to have the same batch axis
and equal length along that axis.
@@ -82,14 +82,14 @@ class MSE(Loss):
def __call__(
self,
model: PyTree,
data: PyTree,
batch_axis: int | None | Sequence[Any] = 0,
batch: PyTree,
batch_axes: int | None | Sequence[Any] = 0,
) -> Scalar:
x, y = data
if isinstance(batch_axis, tuple):
in_axes = batch_axis[0]
x, y = batch
if isinstance(batch_axes, tuple):
in_axes = batch_axes[0]
else:
in_axes = batch_axis
in_axes = batch_axes
y_pred = jax.vmap(model, in_axes=(in_axes,))(x)
return jnp.mean(jnp.square(y_pred - y))

@@ -98,7 +98,7 @@ def __call__(


class MAE(Loss):
"""Mean absolute error for a tuple of data `(x, y)`.
"""Mean absolute error for a batch of data of the form `(x, y)`.

The inputs `x` and the outputs `y` are expected to have the same batch axis
and equal length along that axis.
@@ -107,14 +107,14 @@ class MAE(Loss):
def __call__(
self,
model: PyTree,
data: PyTree,
batch_axis: int | None | Sequence[Any] = 0,
batch: PyTree,
batch_axes: int | None | Sequence[Any] = 0,
) -> Scalar:
x, y = data
if isinstance(batch_axis, tuple):
in_axes = batch_axis[0]
x, y = batch
if isinstance(batch_axes, tuple):
in_axes = batch_axes[0]
else:
in_axes = batch_axis
in_axes = batch_axes
y_pred = jax.vmap(model, in_axes=(in_axes,))(x)
return jnp.mean(jnp.abs(y_pred - y))

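The `Loss` protocol above only fixes the call signature, so a plain function of the form `(model, batch, batch_axes)` works as a drop-in loss. Below is a sketch of such a custom loss; the per-example weighting is a made-up illustration, only the renamed signature is taken from this diff:

```python
import jax
import jax.numpy as jnp


def weighted_mae(model, batch, batch_axes=0):
    """Custom loss with the same call signature as MSE and MAE above."""
    x, y = batch
    # As in MSE/MAE, vmap the model over the batch axis of the inputs.
    in_axes = batch_axes[0] if isinstance(batch_axes, tuple) else batch_axes
    y_pred = jax.vmap(model, in_axes=(in_axes,))(x)
    # Hypothetical per-example weights, ramping from 0.5 to 1.5 across the batch.
    weights = jnp.linspace(0.5, 1.5, y.shape[0])
    return jnp.mean(weights * jnp.mean(jnp.abs(y_pred - y), axis=-1))
```

Such a function could then be passed as `loss_fn` to `fit` in the next file.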
36 changes: 18 additions & 18 deletions klax/_training.py
@@ -37,12 +37,12 @@


@overload
def fit[T: eqx.Module](
def fit[T: PyTree[Any]](
model: T,
data: PyTree[Any],
*,
batch_size: int = 32,
batch_axis: PyTree[int | None] = 0,
batch_axes: PyTree[int | None] = 0,
validation_data: PyTree[Any] = None,
steps: int = 1000,
loss_fn: Loss = mse,
@@ -54,12 +54,12 @@ def fit[T: eqx.Module](
key: PRNGKeyArray,
) -> tuple[T, HistoryCallback]: ...
@overload
def fit[T: eqx.Module, H: Callback](
def fit[T: PyTree[Any], H: Callback](
model: T,
data: PyTree[Any],
*,
batch_size: int = 32,
batch_axis: PyTree[int | None] = 0,
batch_axes: PyTree[int | None] = 0,
validation_data: PyTree[Any] = None,
steps: int = 1000,
loss_fn: Loss = mse,
@@ -70,12 +70,12 @@ def fit[T: eqx.Module, H: Callback](
callbacks: Iterable[Callback] | None = None,
key: PRNGKeyArray,
) -> tuple[T, H]: ...
def fit[T: eqx.Module, H: Callback](
def fit[T: PyTree[Any], H: Callback](
model: T,
data: PyTree[Any],
*,
batch_size: int = 32,
batch_axis: PyTree[int | None] = 0,
batch_axes: PyTree[int | None] = 0,
validation_data: PyTree[Any] = None,
steps: int = 1000,
loss_fn: Loss = mse,
@@ -96,9 +96,9 @@ def fit[T: eqx.Module, H: Callback](
Most likely you'll want `data` to be a tuple `(x, y)` with model
inputs `x` and model outputs `y`.
batch_size: The number of examples in a batch.
batch_axis: A `PyTree` denoting which axis is the batch axis for
arrays in `data`. `batch_axis` must be a prefix of `data`. By
specifying `batch_axis` as a `PyTree` it is possible to specify
batch_axes: A `PyTree` denoting which axis is the batch axis for
arrays in `data`. `batch_axes` must be a prefix of `data`. By
specifying `batch_axes` as a `PyTree` it is possible to specify
different batch axes for different leaves of `data`. (Defaults to
`0`, meaning the first axes of arrays in `data` are batch
dimensions.)
@@ -107,7 +107,7 @@ def fit[T: eqx.Module, H: Callback](
to None.)
steps: Number of gradient updates to apply. (Defaults to 1000.)
loss_fn: The loss function with call signature
`(model: PyTree, data: PyTree, batch_axis: int | None |
`(model: PyTree, data: PyTree, batch_axes: int | None |
Sequence[Any]) -> float`. (Defaults to `mse`.)
optimizer: The optimizer. Any optax gradient transform to calculate
the updates for the model. (Defaults to optax.adam(1e-3).)
@@ -140,20 +140,20 @@ def fit[T: eqx.Module, H: Callback](
A tuple of the trained model and the loss history.

"""
# Broadcast the batch_axis to the data. While this happens again in the
# batch_data, doing it here allows the use of the broadcasted batch_axis in
# the loss function. If `batch_axis` is a prefix of `data`, this ensures
# Broadcast the batch axes to the data. While this happens again in the
# batch_data, doing it here allows the use of the broadcasted batch_axes in
# the loss function. If `batch_axes` is a prefix of `data`, this ensures
# that only leaves of type ArrayLike are vmapped. Thus it is possible to
# have data like `(str, array)` and still use `batch_axis=0` instead of
# `batch_axis=(None, 0)`.
batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis)
# have data like `(str, array)` and still use `batch_axes=0` instead of
# `batch_axes=(None, 0)`.
batch_axes, dataset_size = broadcast_and_get_size(data, batch_axes)

# Define a function to calculate the loss. This is jit compiled to speed up
# the loss evaluation for the loss history.
@eqx.filter_jit
def combined_loss(model, batch):
model = unwrap(model)
return loss_fn(model, batch, batch_axis=batch_axis)
return loss_fn(model, batch, batch_axes=batch_axes)

# This partitioned loss function is required within the make_step function,
# because the optax.lbgfs GradientTransformation required the loss function
@@ -235,7 +235,7 @@ def make_step(batch, flat_model, optimizer, flat_opt_state):
# Loop over all training steps
for step, batch in zip(
range(1, steps + 1),
batcher(data, batch_size, batch_axis, key=key),
batcher(data, batch_size, batch_axes, key=key),
):
flat_model, flat_opt_state = make_step(
batch, flat_model, optimizer, flat_opt_state
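For reference, a minimal end-to-end call of `fit` with the renamed keyword. The toy regression data and the `equinox.nn.MLP` model are illustrative assumptions; any model compatible with the default `mse` loss would work:

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import klax

model_key, data_key, fit_key = jr.split(jr.PRNGKey(0), 3)

# Toy regression data: 256 examples batched along axis 0.
x = jr.normal(data_key, (256, 2))
y = jnp.sum(x**2, axis=-1, keepdims=True)

model = eqx.nn.MLP(in_size=2, out_size=1, width_size=32, depth=2, key=model_key)

model, history = klax.fit(
    model,
    (x, y),
    batch_size=32,
    batch_axes=0,  # the first axis of both x and y is the batch axis
    steps=1000,
    key=fit_key,
)
```

The second return value is the loss history, a `HistoryCallback` by default as in the first overload above.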
4 changes: 2 additions & 2 deletions klax/_wrappers.py
@@ -369,7 +369,7 @@ def apply(self) -> Self:
pass


def apply(tree: PyTree):
def apply(tree: PyTree) -> PyTree:
"""Map across a PyTree and apply all [Constraints][klax.Constraint].

This leaves all other nodes unchanged.
@@ -437,7 +437,7 @@ def apply(self) -> Self:
# ===----------------------------------------------------------------------===#


def finalize(tree: PyTree):
def finalize(tree: PyTree) -> PyTree:
"""Make a model containing [Constraints][klax.Constraint] callable.

This function combines the functionalities of [`klax.apply`][] and
1 change: 0 additions & 1 deletion pyproject.toml
@@ -17,7 +17,6 @@ dependencies = [
"equinox>=0.12.2",
"jax>=0.6.0",
"optax>=0.2.4",
"paramax>=0.0.3",
]
classifiers = [
"Development Status :: 3 - Alpha",