diff --git a/klax/__init__.py b/klax/__init__.py index 385282b..6cd54ae 100644 --- a/klax/__init__.py +++ b/klax/__init__.py @@ -16,12 +16,15 @@ from ._callbacks import ( Callback as Callback, ) -from ._callbacks import ( - CallbackArgs as CallbackArgs, -) from ._callbacks import ( HistoryCallback as HistoryCallback, ) +from ._context import ( + EvaluationContext, + TimingInfo, + TrainingContext, + TrainingState, +) from ._datahandler import ( BatchGenerator as BatchGenerator, ) @@ -32,13 +35,7 @@ split_data as split_data, ) from ._losses import ( - MAE as MAE, -) -from ._losses import ( - MSE as MSE, -) -from ._losses import ( - Loss as Loss, + LossFactory as LossFactory, ) from ._losses import ( mae as mae, diff --git a/klax/_callbacks.py b/klax/_callbacks.py index c4df84a..d1acd5a 100644 --- a/klax/_callbacks.py +++ b/klax/_callbacks.py @@ -16,7 +16,7 @@ import importlib import pickle import time -from abc import ABC +from abc import ABC, abstractmethod from collections.abc import Callable from pathlib import Path from typing import Any, Self @@ -24,132 +24,7 @@ import jax from jaxtyping import PyTree, PyTreeDef, Scalar - -class CallbackArgs: - """A callback argument designed to work in conjunction with [`klax.fit`][]. - - This class should not be instantiated directly. An instance of this class - is passed to every callback object in the fit function. When writing a - custom callback, use the properties of this class to access the current - model, optimizer state, training data, and validation data during training. - - This class implements cached and lazy-evaluated values via property - methods. This means that properties like ``loss`` are only calculated if - they are used and are stored such that they are not calculated multiple - times. - """ - - step: int #: Current step-count of the training. - time_on_last_update: float #: Global time of the last :meth:`update` call. - data: PyTree #: PyTree of the training data. - val_data: PyTree | None #: PyTree of the validation data. - _treedef_model: PyTreeDef - _flat_model: list - _treedef_opt_state: PyTreeDef - _flat_opt_state: list - _cache: dict = {} - _get_loss: Callable[[PyTree, PyTree], Scalar] - _start_time: float - - def __init__( - self, - get_loss: Callable[[PyTree, PyTree], Scalar], - treedef_model: PyTreeDef, - treedef_opt_state: PyTreeDef, - data: PyTree, - val_data: PyTree | None = None, - ): - """Initialize the callback arguments object. - - Args: - get_loss: Function that takes a model and a batch of data and - returns the loss. - treedef_model: Tree structure of the model. - treedef_opt_state: Tree structure of the :py:mod:`optax` optimizer. - data: PyTree of the training data. - val_data: PyTree of the validation data. If None, no validation - loss is calculated and the property :py:attr:`val_loss` will - return None. - - """ - self.data = data - self.val_data = val_data - self._get_loss = get_loss - self._treedef_model = treedef_model - self._treedef_opt_state = treedef_opt_state - - def update(self, flat_model: PyTree, flat_opt_state: PyTree, step: int): - """Update the object with the current model and optimizer state. - - This method is called repeatedly in [`klax.fit`][]. - - Args: - flat_model: Flattened PyTree of the model. - flat_opt_state: Flattened PyTree of the `optax` - optimizer. - step: Current step-count of the training. 
- - """ - self._flat_model = flat_model - self._flat_opt_state = flat_opt_state - self.step = step - self.time_on_last_update = time.time() - - # Clear cache - self._cache = {} - - @staticmethod - def _lazy_evaluated_and_cached(fun: Callable[[Any], Any]) -> property: - """Turn a public method into a property. - - The return value of ``fun`` is stored in the ``_cache`` dictionary of - the current object using the function name as key. If the name is - already in ``_cache`` then the cached value is simply returned, - without evaluating ``fun``. - - Args: - fun: Method to wrap. - - Returns: - Wrapped method as a property. - - """ - attr_name = fun.__name__ - - def wrapper(self: Self): - if attr_name not in self._cache: - self._cache.setdefault(attr_name, fun(self)) - return self._cache.get(attr_name) - - wrapper.__doc__ = fun.__doc__ - - return property(wrapper) - - @_lazy_evaluated_and_cached - def model(self): - """Lazy-evaluated and cached model.""" - return jax.tree_util.tree_unflatten( - self._treedef_model, self._flat_model - ) - - @_lazy_evaluated_and_cached - def opt_state(self): - """Lazy-evaluated and cached optimizer state.""" - return jax.tree_util.tree_unflatten( - self._treedef_opt_state, self._flat_opt_state - ) - - @_lazy_evaluated_and_cached - def loss(self): - """Lazy-evaluated and cached training loss.""" - return self._get_loss(self.model, self.data) - - @_lazy_evaluated_and_cached - def val_loss(self) -> Scalar | None: - """Lazy-evaluated and cached validation loss.""" - if self.val_data is None: - return None - return self._get_loss(self.model, self.val_data) +from ._context import TrainingContext class Callback(ABC): @@ -158,19 +33,168 @@ class Callback(ABC): Inherit from this class to create a custom callback. """ - def __call__(self, cbargs: CallbackArgs) -> bool | None: + @abstractmethod + def __call__(self, ctx: TrainingContext) -> bool | None: """Call after each step during training.""" pass - def on_training_end(self, cbargs: CallbackArgs) -> None: + @abstractmethod + def on_training_end(self, ctx: TrainingContext) -> None: """Call when training ends.""" pass - def on_training_start(self, cbargs: CallbackArgs) -> None: + @abstractmethod + def on_training_start(self, ctx: TrainingContext) -> None: """Call when training starts.""" pass +# class CallbackArgs: +# """A callback argument designed to work in conjunction with [`klax.fit`][]. + +# This class should not be instantiated directly. An instance of this class +# is passed to every callback object in the fit function. When writing a +# custom callback, use the properties of this class to access the current +# model, optimizer state, training data, and validation data during training. + +# This class implements cached and lazy-evaluated values via property +# methods. This means that properties like ``loss`` are only calculated if +# they are used and are stored such that they are not calculated multiple +# times. +# """ + +# step: int #: Current step-count of the training. +# time_on_last_update: float #: Global time of the last :meth:`update` call. +# data: PyTree #: PyTree of the training data. +# val_data: PyTree | None #: PyTree of the validation data. 
+# _treedef_model: PyTreeDef +# _flat_model: list +# _treedef_opt_state: PyTreeDef +# _flat_opt_state: list +# _cache: dict = {} +# _get_loss: Callable[[PyTree, PyTree], Scalar] +# _start_time: float + +# def __init__( +# self, +# get_loss: Callable[[PyTree, PyTree], Scalar], +# treedef_model: PyTreeDef, +# treedef_opt_state: PyTreeDef, +# data: PyTree, +# val_data: PyTree | None = None, +# ): +# """Initialize the callback arguments object. + +# Args: +# get_loss: Function that takes a model and a batch of data and +# returns the loss. +# treedef_model: Tree structure of the model. +# treedef_opt_state: Tree structure of the :py:mod:`optax` optimizer. +# data: PyTree of the training data. +# val_data: PyTree of the validation data. If None, no validation +# loss is calculated and the property :py:attr:`val_loss` will +# return None. + +# """ +# self.data = data +# self.val_data = val_data +# self._get_loss = get_loss +# self._treedef_model = treedef_model +# self._treedef_opt_state = treedef_opt_state + +# def update(self, flat_model: PyTree, flat_opt_state: PyTree, step: int): +# """Update the object with the current model and optimizer state. + +# This method is called repeatedly in [`klax.fit`][]. + +# Args: +# flat_model: Flattened PyTree of the model. +# flat_opt_state: Flattened PyTree of the `optax` +# optimizer. +# step: Current step-count of the training. + +# """ +# self._flat_model = flat_model +# self._flat_opt_state = flat_opt_state +# self.step = step +# self.time_on_last_update = time.time() + +# # Clear cache +# self._cache = {} + +# @staticmethod +# def _lazy_evaluated_and_cached(fun: Callable[[Any], Any]) -> property: +# """Turn a public method into a property. + +# The return value of ``fun`` is stored in the ``_cache`` dictionary of +# the current object using the function name as key. If the name is +# already in ``_cache`` then the cached value is simply returned, +# without evaluating ``fun``. + +# Args: +# fun: Method to wrap. + +# Returns: +# Wrapped method as a property. + +# """ +# attr_name = fun.__name__ + +# def wrapper(self: Self): +# if attr_name not in self._cache: +# self._cache.setdefault(attr_name, fun(self)) +# return self._cache.get(attr_name) + +# wrapper.__doc__ = fun.__doc__ + +# return property(wrapper) + +# @_lazy_evaluated_and_cached +# def model(self): +# """Lazy-evaluated and cached model.""" +# return jax.tree_util.tree_unflatten( +# self._treedef_model, self._flat_model +# ) + +# @_lazy_evaluated_and_cached +# def opt_state(self): +# """Lazy-evaluated and cached optimizer state.""" +# return jax.tree_util.tree_unflatten( +# self._treedef_opt_state, self._flat_opt_state +# ) + +# @_lazy_evaluated_and_cached +# def loss(self): +# """Lazy-evaluated and cached training loss.""" +# return self._get_loss(self.model, self.data) + +# @_lazy_evaluated_and_cached +# def val_loss(self) -> Scalar | None: +# """Lazy-evaluated and cached validation loss.""" +# if self.val_data is None: +# return None +# return self._get_loss(self.model, self.val_data) + + +# class Callback(ABC): +# """An abstract callback. + +# Inherit from this class to create a custom callback. 
+# """ + +# def __call__(self, cbargs: CallbackArgs) -> bool | None: +# """Call after each step during training.""" +# pass + +# def on_training_end(self, cbargs: CallbackArgs) -> None: +# """Call when training ends.""" +# pass + +# def on_training_start(self, cbargs: CallbackArgs) -> None: +# """Call when training starts.""" +# pass + + class HistoryCallback(Callback): """Default callback for logging a training process. @@ -181,8 +205,6 @@ class HistoryCallback(Callback): steps: list #: List of steps at which the losses were recorded. loss: list val_loss: list - last_start_time: float # start time of the last training - last_end_time: float # End time of the last training training_time: float = 0 # Total training time of all trainings verbose: bool step_offset: int = 0 # Potential offset due to previous trainings @@ -211,44 +233,45 @@ def __repr__(self): f"verbose={self.verbose})" ) - def __call__(self, cbargs: CallbackArgs): + def __call__(self, ctx: TrainingContext): """Record the losses and step count. Called at each step during training. """ - if cbargs.step % self.log_every == 0: - self.steps.append(self.step_offset + cbargs.step) - self.loss.append(cbargs.loss) - self.val_loss.append(cbargs.val_loss) + if ctx.step % self.log_every == 0: + self.steps.append(self.step_offset + ctx.step) + self.loss.append(ctx.loss) + self.val_loss.append(ctx.val_loss) # Print message if self.verbose: - message = f"Step: {cbargs.step}, Loss: {cbargs.loss:.3e}" - if cbargs.val_data is not None: - message += f", Validation loss: {cbargs.val_loss:.3e}" + message = f"Step: {ctx.step}, Loss: {ctx.loss:.3e}" + if ctx.val_data is not None: + message += f", Validation loss: {ctx.val_loss:.3e}" print(message) - def on_training_start(self, cbargs: CallbackArgs): + def on_training_start(self, ctx: TrainingContext): """Initialize the training start time. Called at beginning of training. """ - self.last_start_time = cbargs.time_on_last_update + assert ctx.start_time is not None + self.last_start_time = ctx.start_time if self.steps: # If there are already steps, we assume that this is a continuation # of a training. self.step_offset = self.steps[-1] else: - self(cbargs) + self(ctx) - def on_training_end(self, cbargs: CallbackArgs): + def on_training_end(self, ctx: TrainingContext): """Record the training end time and the last optimizer state. Called at end of training. 
""" - self.last_end_time = cbargs.time_on_last_update - self.training_time += self.last_end_time - self.last_start_time - self.last_opt_state = cbargs.opt_state + assert ctx.total_time is not None + self.training_time += ctx.total_time + self.last_opt_state = ctx.opt_state if self.verbose: print( f"Training took: { diff --git a/klax/_context.py b/klax/_context.py new file mode 100644 index 0000000..b86f722 --- /dev/null +++ b/klax/_context.py @@ -0,0 +1,235 @@ +import time +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Self + +import equinox as eqx +import jax +from jaxtyping import PyTree, Scalar + +from ._losses import ValueFn +from ._wrappers import apply + + +class TrainingState: + _flat_model: PyTree + _flat_opt_state: PyTree + _treedef_model: PyTree + _treedef_opt_state: PyTree + _step: int + _cache: dict[str, Any] = {} + + def __init__( + self, model: PyTree, opt_state: PyTree = None, initial_step: int = 0 + ): + # Apply the Constraint in the model to ensure apply-constrains are met + # initially + model = apply(model) + + # Use the unflatten trick to speed up training, + # see https://docs.kidger.site/equinox/tricks/ + flat_model, treedef_model = jax.tree.flatten(model) + flat_opt_state, treedef_opt_state = jax.tree.flatten(opt_state) + + self._flat_model = flat_model + self._flat_opt_state = flat_opt_state + self._treedef_model = treedef_model + self._treedef_opt_state = treedef_opt_state + self._step = initial_step + + @staticmethod + def _lazy_chached_property(fun: Callable) -> property: + """Turn a public method into a lazily evaluated property. + + The return value of ``fun`` is stored in the ``_cache`` dictionary of + the current object using the function name as key. If the name is + already in ``_cache`` then the cached value is simply returned, + without evaluating ``fun``. + + Args: + fun: Method to wrap. + + Returns: + Wrapped method as a property. 
+ + """ + attr_name = fun.__name__ + + def wrapper(self: Self): + if attr_name not in self._cache: + self._cache.setdefault(attr_name, fun(self)) + return self._cache.get(attr_name) + + wrapper.__doc__ = fun.__doc__ + + return property(wrapper) + + @property + def flat_model(self) -> PyTree: + return self._flat_model + + @property + def flat_opt_state(self) -> PyTree: + return self._flat_opt_state + + @property + def step(self) -> int: + return self._step + + @_lazy_chached_property + def model(self) -> PyTree: + return jax.tree_util.tree_unflatten( + self._treedef_model, self._flat_model + ) + + @_lazy_chached_property + def opt_state(self) -> PyTree: + return jax.tree_util.tree_unflatten( + self._treedef_opt_state, self._flat_opt_state + ) + + @property + def treedef_model(self) -> PyTree: + return self._treedef_model + + @property + def treedef_opt_state(self) -> PyTree: + return self._treedef_opt_state + + def update(self, flat_model: PyTree, flat_opt_state: PyTree): + self._flat_model = flat_model + self._flat_opt_state = flat_opt_state + self._step += 1 + + # Clear cache + self._cache.clear() + + +@dataclass +class EvaluationContext: + value_fn: ValueFn + data: PyTree[Any] + val_data: PyTree[Any] | None = None + _cached_step: int | None = None + _cache: dict[str, Any] = field(default_factory=dict) + + def _ensure_step(self, state: TrainingState): + if self._cached_step != state.step: + self._cache.clear() + self._cached_step = state.step + + @staticmethod + @eqx.filter_jit + def _loss_impl(value_fn: ValueFn, model: PyTree, batch: PyTree[Any]): + return value_fn(model, batch) + + def loss(self, state: TrainingState) -> Scalar: + self._ensure_step(state) + if "loss" not in self._cache: + self._cache["loss"] = self._loss_impl( + self.value_fn, state.model, self.data + ) + return self._cache["loss"] + + def val_loss(self, state: TrainingState) -> Scalar | None: + self._ensure_step(state) + if self.val_data is None: + return None + + if "val_loss" not in self._cache: + self._cache["val_loss"] = self._loss_impl( + self.value_fn, state.model, self.val_data + ) + return self._cache["val_loss"] + + +@dataclass +class TimingInfo: + start_time: float | None = None + total_time: float = 0.0 + time_of_last_update: float | None = None + + def update(self): + time_of_last_update = time.time() + if self.start_time is None: + self.start_time = time_of_last_update + else: + self.total_time = time_of_last_update - self.start_time + self.time_of_last_update = time_of_last_update + + +@dataclass +class TrainingContext: + _state: TrainingState + _evaluator: EvaluationContext + _timer: TimingInfo + + def __init__( + self, + state: TrainingState, + evaluator: EvaluationContext, + timing: TimingInfo, + ): + self._state = state + self._evaluator = evaluator + self._timer = timing + + def update(self, flat_model: PyTree, flat_opt_state: PyTree) -> None: + self._state.update(flat_model, flat_opt_state) + self._timer.update() + + @property + def flat_opt_state(self) -> PyTree: + return self._state.flat_opt_state + + @property + def flat_model(self) -> PyTree: + return self._state.flat_model + + @property + def model(self) -> PyTree: + return self._state.model + + @property + def treedef_model(self) -> PyTree: + return self._state.treedef_model + + @property + def opt_state(self) -> PyTree: + return self._state.opt_state + + @property + def treedef_opt_state(self) -> PyTree: + return self._state.treedef_opt_state + + @property + def step(self) -> int: + return self._state.step + + @property + def loss(self) -> 
Scalar: + return self._evaluator.loss(self._state) + + @property + def val_loss(self) -> Scalar | None: + return self._evaluator.val_loss(self._state) + + @property + def data(self) -> PyTree: + return self._evaluator.data + + @property + def val_data(self) -> PyTree: + return self._evaluator.val_data + + @property + def time_of_last_update(self) -> float | None: + return self._timer.time_of_last_update + + @property + def start_time(self) -> float | None: + return self._timer.start_time + + @property + def total_time(self) -> float | None: + return self._timer.total_time diff --git a/klax/_datahandler.py b/klax/_datahandler.py index a9aa536..f446ec9 100644 --- a/klax/_datahandler.py +++ b/klax/_datahandler.py @@ -16,6 +16,7 @@ import typing import warnings +from abc import abstractmethod from collections.abc import Generator, Sequence from typing import Any, Protocol @@ -26,6 +27,10 @@ import numpy as np from jaxtyping import PRNGKeyArray, PyTree +from ._context import TrainingContext + +BatchGenerator = Generator[PyTree[Any], TrainingContext, None] + def broadcast_and_get_size( data: PyTree[Any], batch_axis: PyTree[int | None] @@ -78,87 +83,28 @@ def broadcast_and_get_size( @typing.runtime_checkable -class BatchGenerator(Protocol): +class Batcher(Protocol): + @abstractmethod def __call__( self, data: PyTree[Any], batch_size: int, - batch_axis: PyTree[int | None], + batch_axis: int, *, key: PRNGKeyArray, - ) -> Generator[PyTree[Any], None, None]: - raise NotImplementedError + ) -> BatchGenerator: + pass def batch_data( data: PyTree[Any], - batch_size: int = 32, - batch_axis: PyTree[int | None] = 0, + batch_size: int, + batch_axis: int, convert_to_numpy: bool = True, *, - key: PRNGKeyArray, -) -> Generator[PyTree[Any], None, None]: - """Create a `Generator` that draws subsets of data without replacement. - - The data can be any `PyTree` with `ArrayLike` leaves. If `batch_axis` is - passed, batch axes (including `None` for no batching) can be specified for - every leaf individualy. - A generator is returned that indefinetly yields batches of data with size - `batch_size`. Examples are drawn without replacement until the remaining - dataset is smaller than `batch_size`, at which point the dataset will be - reshuffeld and the process starts over. - - Example: - This is an example for a nested `PyTree`, where the elements x and y - have batch dimension along the first axis. - - ```python - >>> import klax - >>> import jax - >>> import jax.numpy as jnp - >>> - >>> x = jnp.array([1., 2.]) - >>> y = jnp.array([[1.], [2.]]) - >>> data = (x, {"a": 1.0, "b": y}) - >>> batch_axis = (0, {"a": None, "b": 0}) - >>> iter_data = klax.batch_data( - ... data, - ... 32, - ... batch_axis, - ... key=jax.random.key(0) - ... ) - >>> - ``` - - Args: - data: The data that shall be batched. It can be any `PyTree` with - `ArrayLike` leaves. - batch_size: The number of examples in a batch. - batch_axis: PyTree of the batch axis indices. `None` is used to - indicate that the corresponding leaf or subtree in data does not - have a batch axis. `batch_axis` must have the same structure as - `data` or have `data` as a prefix. - (Defaults to 0, meaning all leaves in `data` are - batched along their first dimension.) - convert_to_numpy: If `True`, batched data leafs will be converted to - Numpy arrays before batching. This is useful for performance - reasons, as Numpy's slicing is much faster than JAX's. - key: A `jax.random.PRNGKey` used to provide randomness for batch - generation. (Keyword only argument.) 
- - Returns: - A `Generator` that yields a random batch of data every time is is - called. - - Yields: - A `PyTree[ArrayLike]` with the same structure as `data`. Where all - batched leaves have `batch_size`. - - Note: - Note that if the size of the dataset is smaller than `batch_size`, the - obtained batches will have dataset size. - - """ + key: PRNGKeyArray, # Only cor compliance with the `Batcher` protocol +) -> BatchGenerator: + """Create a stateful batch generator that uses the step as seed.""" batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis) # Convert to Numpy arrays. Numpy's slicing is much faster than JAX's, so @@ -177,6 +123,9 @@ def batch_data( indices = jnp.arange(dataset_size) while True: + # Store the training state as received by the `.send(state)` within + # the training loop. + ctx: TrainingContext = yield perm = jr.permutation(key, indices) (key,) = jr.split(key, 1) # Update key start, end = 0, batch_size diff --git a/klax/_losses.py b/klax/_losses.py index 5e5b7b1..184f9de 100644 --- a/klax/_losses.py +++ b/klax/_losses.py @@ -14,109 +14,65 @@ import typing from abc import abstractmethod -from collections.abc import Sequence from typing import Any, Protocol +import equinox as eqx import jax import jax.numpy as jnp from jaxtyping import PyTree, Scalar @typing.runtime_checkable -class Loss(Protocol): - """An abstract callable loss object. +class ValueFn(Protocol): + @abstractmethod + def __call__(self, model: PyTree, data: PyTree) -> Scalar: + pass - It can be used to build custom losses that can be passed to [`klax.fit`][]. - Example: - A simple custom loss that computes the mean squared error between - the predicted values `y_pred` and true values `y` for in inputs `x` may - be implemented as follows: +@typing.runtime_checkable +class ValueAndGradFn(Protocol): + @abstractmethod + def __call__(self, model: PyTree, data: PyTree) -> tuple[Scalar, PyTree]: + pass - ```python - >>> def mse(model, data, batch_axis=0): - ... x, y = data - ... if isinstance(batch_axis, tuple): - ... in_axes = batch_axis[0] - ... else: - ... in_axes = batch_axis - ... y_pred = jax.vmap(model, in_axes=(in_axes,))(x) - ... return jnp.mean(jnp.square(y_pred - y)) - ``` - Note that, since we a aim to provide a maximum of flexibility the users - have to take care of applying `jax.vmap` to the model themselves. +class LossFactory(Protocol): + @abstractmethod + def __call__(self, batch_axis) -> tuple[ValueFn, ValueAndGradFn]: + pass - """ - @abstractmethod - def __call__( - self, - model: PyTree, - data: PyTree, - batch_axis: int | None | Sequence[Any], - ) -> Scalar: - """Abstract method to compute the loss for a given model and data. - - Args: - model: The model parameters or structure to evaluate the loss. - data: The input data or structure used for loss computation. - batch_axis: Specifies the axis or axes corresponding to the batch - dimension in the data. Can be an integer, None, or a sequence - of values. - - Returns: - Scalar: The computed loss value. - - """ - ... - - -class MSE(Loss): - """Mean squared error for a tuple of data `(x, y)`. - - The inputs `x` and the outputs `y` are expected to have the same batch axis - and equal length along that axis. 
- """ - - def __call__( - self, - model: PyTree, - data: PyTree, - batch_axis: int | None | Sequence[Any] = 0, - ) -> Scalar: +def mse(batch_axis) -> tuple[ValueFn, ValueAndGradFn]: + def value_fn(model: PyTree, data: PyTree) -> Scalar: x, y = data if isinstance(batch_axis, tuple): in_axes = batch_axis[0] else: in_axes = batch_axis - y_pred = jax.vmap(model, in_axes=(in_axes,))(x) + y_pred = eqx.filter_vmap(model, in_axes=(in_axes,))(x) return jnp.mean(jnp.square(y_pred - y)) + def value_and_grad_fn( + model: PyTree, data: PyTree + ) -> tuple[Scalar, PyTree]: + return eqx.filter_value_and_grad(value_fn)(model, data) -mse = MSE() - - -class MAE(Loss): - """Mean absolute error for a tuple of data `(x, y)`. + return value_fn, value_and_grad_fn - The inputs `x` and the outputs `y` are expected to have the same batch axis - and equal length along that axis. - """ - def __call__( - self, - model: PyTree, - data: PyTree, - batch_axis: int | None | Sequence[Any] = 0, - ) -> Scalar: +def mae(batch_axis) -> tuple[ValueFn, ValueAndGradFn]: + def value_fn(model: PyTree, data: PyTree) -> Scalar: x, y = data if isinstance(batch_axis, tuple): in_axes = batch_axis[0] else: in_axes = batch_axis - y_pred = jax.vmap(model, in_axes=(in_axes,))(x) + y_pred = eqx.filter_vmap(model, in_axes=(in_axes,))(x) return jnp.mean(jnp.abs(y_pred - y)) + def value_and_grad_fn( + model: PyTree, data: PyTree + ) -> tuple[Scalar, PyTree]: + return eqx.filter_value_and_grad(value_fn)(model, data) -mae = MAE() + return value_fn, value_and_grad_fn diff --git a/klax/_training.py b/klax/_training.py index 167348d..3cd9d95 100644 --- a/klax/_training.py +++ b/klax/_training.py @@ -12,182 +12,56 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Implements a basic training loop.""" +"""Implements a training loop.""" -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, overload +from collections.abc import Callable, Generator, Iterable +from typing import Any, Protocol, Self import equinox as eqx import jax +import jax.numpy as jnp +import numpy as np import optax -from jaxtyping import PRNGKeyArray, PyTree - -from ._callbacks import ( - Callback, - CallbackArgs, - HistoryCallback, +from jaxtyping import PRNGKeyArray, PyTree, Scalar +from optax._src.utils import Sequence + +from ._callbacks import Callback, HistoryCallback +from ._context import ( + EvaluationContext, + TimingInfo, + TrainingContext, + TrainingState, ) -from ._datahandler import ( - BatchGenerator, - batch_data, - broadcast_and_get_size, +from ._datahandler import Batcher, BatchGenerator, batch_data +from ._losses import LossFactory, mse +from ._updaters import ( + Updater, + UpdaterFactory, + optax_transform_update_fn_updater, ) -from ._losses import Loss, mse -from ._wrappers import apply, unwrap - - -@overload -def fit[T: eqx.Module]( - model: T, - data: PyTree[Any], - *, - batch_size: int = 32, - batch_axis: PyTree[int | None] = 0, - validation_data: PyTree[Any] = None, - steps: int = 1000, - loss_fn: Loss = mse, - optimizer: optax.GradientTransformation = optax.adam(1e-3), - init_opt_state: PyTree[Any] = None, - batcher: BatchGenerator = batch_data, - history: None = None, - callbacks: Iterable[Callback] | None = None, - key: PRNGKeyArray, -) -> tuple[T, HistoryCallback]: ... 
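Aside (illustrative, not part of this change set): under the new `LossFactory` protocol introduced in `klax/_losses.py` above, a user-defined loss mirrors `mse` and `mae`: it takes the broadcast `batch_axis` and returns a `(value_fn, value_and_grad_fn)` pair. A sketch of a hypothetical Huber-style loss, assuming the internal `ValueFn`/`ValueAndGradFn` protocols stay importable from `klax._losses`:

```python
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import PyTree, Scalar

# Internal protocols; only LossFactory is re-exported at package level.
from klax._losses import ValueAndGradFn, ValueFn


def huber(batch_axis) -> tuple[ValueFn, ValueAndGradFn]:
    def value_fn(model: PyTree, data: PyTree) -> Scalar:
        x, y = data
        in_axes = batch_axis[0] if isinstance(batch_axis, tuple) else batch_axis
        y_pred = eqx.filter_vmap(model, in_axes=(in_axes,))(x)
        err = jnp.abs(y_pred - y)
        # Quadratic near zero, linear in the tails (delta = 1).
        return jnp.mean(jnp.where(err < 1.0, 0.5 * err**2, err - 0.5))

    def value_and_grad_fn(model: PyTree, data: PyTree) -> tuple[Scalar, PyTree]:
        return eqx.filter_value_and_grad(value_fn)(model, data)

    return value_fn, value_and_grad_fn
```

Such a factory can be passed to the new `fit(..., loss=huber, ...)` in place of `mse`.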
-@overload -def fit[T: eqx.Module, H: Callback]( - model: T, - data: PyTree[Any], - *, - batch_size: int = 32, - batch_axis: PyTree[int | None] = 0, - validation_data: PyTree[Any] = None, - steps: int = 1000, - loss_fn: Loss = mse, - optimizer: optax.GradientTransformation = optax.adam(1e-3), - init_opt_state: PyTree[Any] = None, - batcher: BatchGenerator = batch_data, - history: H, - callbacks: Iterable[Callback] | None = None, - key: PRNGKeyArray, -) -> tuple[T, H]: ... -def fit[T: eqx.Module, H: Callback]( - model: T, - data: PyTree[Any], - *, - batch_size: int = 32, - batch_axis: PyTree[int | None] = 0, - validation_data: PyTree[Any] = None, - steps: int = 1000, - loss_fn: Loss = mse, - optimizer: optax.GradientTransformation = optax.adam(1e-3), - init_opt_state: PyTree[Any] = None, - batcher: BatchGenerator = batch_data, - history: HistoryCallback | H | None = None, - callbacks: Iterable[Callback] | None = None, - key: PRNGKeyArray, -) -> tuple[T, HistoryCallback | H]: - """Trains a model using an optimizer from optax. - - Args: - model: The model instance, which should be trained. It must be a - subclass of `equinox.Module`. The model may contain - [`klax.Unwrappable`][] wrappers. - data: The training data can be any `PyTree` with `ArrayLike` leaves. - Most likely you'll want `data` to be a tuple `(x, y)` with model - inputs `x` and model outputs `y`. - batch_size: The number of examples in a batch. - batch_axis: A `PyTree` denoting, which axis is the batch axis for - arrays in `data`. `batch_axis` must be a prefix of `data`. By - specifying `batch_axis` as a `PyTree` it is possible to specify - different batch axes for different leaves of `data`. (Defaults to - `0`, meaning the first axes of arrays in `data` are batch - dimensions.) - validation_data: Arbitrary `PyTree` used for validation during - training. Must have the same tree structure as `data`. (Defaults - to None.) - steps: Number of gradient updates to apply. (Defaults to 1000.) - loss_fn: The loss function with call signature - `(model: PyTree, data: PyTree, batch_axis: int | None | - Sequence[Any]) -> float`. (Defaults to `mse`.) - optimizer: The optimizer. Any optax gradient transform to calculate - the updates for the model. (Defaults to optax.adam(1e-3).) - init_opt_state: The initial state of the optimizer. If `None`, the - optimizer is initialized from scratch. By providing a value for - `init_opt_state`, the user can resume training from a previous - state (e.g., obtained from the `HistoryCallback.last_opt_state`). - (Defaults to `None`.) - batcher: The data loader that splits inputs and targets into batches. - (Defaults to `batch_data`.) - history: A callback intended for tracking the training process. If no - custom callback is passed the [`klax.HistoryCallback`][] with a - logging interval of 100 steps is used. To change the logging - increment or verbosity of this default callback, pass a - `HistoryCallback` object to this argument, e.g., - `history=HistoryCallback(log_every=10, verbose=False)` for logging - on every 10-th step without printing the loss. - callbacks: Callback functions that are evaluated after every training - step. They can be used to implement early stopping, custom history - logging and more. The argument to the callback function is a - CallbackArgs object. (Defaults to `None`. Keyword only Argument) - key: A `jax.random.PRNGKey` used to provide randomness for batch - generation. (Keyword only argument.) 
+from ._wrappers import apply - Note: - This function assumes that the batch dimension is always oriented along - the first axes of any `jax.Array` - - Returns: - A tuple of the trained model and the loss history. - - """ - # Braodcast the batch_axis to the data. While this happens again in the - # batch_data, doing it here allows the use of the broadcasted batch_axis in - # the loss function. If `batch_axis` is a prefix of `data`, this ensures - # that only leafs of type ArrayLike are vmapped. Thus it is possible to - # have data like `(str, array)` ans still use `batch_axis=0` instead of - # `batch_axis=(None, 0)`. - batch_axis, dataset_size = broadcast_and_get_size(data, batch_axis) - - # Define a function to calculate the loss. This is jit compiled to speed up - # the loss evaluation for the loss history. - @eqx.filter_jit - def combined_loss(model, batch): - model = unwrap(model) - return loss_fn(model, batch, batch_axis=batch_axis) - - # This partitioned loss function is required within the make_step function, - # because the optax.lbgfs GradientTransformation required the loss function - # to be diretly dependent on the parameters. - def partitioned_loss(params, static, batch): - model = eqx.combine(params, static) - return combined_loss(model, batch) +def fit_core( + updater: Updater, + batcher: BatchGenerator, + ctx: TrainingContext, + steps: int, + callbacks: Iterable[Callback] | Callback | None = None, +) -> tuple[TrainingContext, list[Callback]]: @eqx.filter_jit - def make_step(batch, flat_model, optimizer, flat_opt_state): + def make_step(batch, flat_model, flat_opt_state): # Use the unflatten trick to speed up training, # see https://docs.kidger.site/equinox/tricks/ - model = jax.tree_util.tree_unflatten(treedef_model, flat_model) + model = jax.tree_util.tree_unflatten(ctx.treedef_model, flat_model) opt_state = jax.tree_util.tree_unflatten( - treedef_opt_state, flat_opt_state + ctx.treedef_opt_state, flat_opt_state ) # Compute and apply the parameter updates - params, static = eqx.partition(model, eqx.is_inexact_array) - value, grad = jax.value_and_grad(partitioned_loss)( - params, static, batch - ) - updates, opt_state = optimizer.update( - grad, - opt_state, - params, - value=value, - grad=grad, - value_fn=jax.tree_util.Partial( - partitioned_loss, static=static, batch=batch - ), - ) - params = optax.apply_updates(params, updates) - model = eqx.combine(params, static) + # params, static = eqx.partition(model, eqx.is_inexact_array) + # params = updater(params, static, batch, opt_state) + model, opt_state = updater(model, batch, opt_state) # Apply the Constraint in the model to ensure apply-constrains are met # after the update. @@ -198,65 +72,129 @@ def make_step(batch, flat_model, optimizer, flat_opt_state): return flat_model, flat_opt_state - if init_opt_state is None: - # Initialize the optimizer and 'tell it' to optimize with respect to - # all inexact arrays in the model. This is done by passing the model to - # the optimizer. 
-        opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))
-    else:
-        opt_state = init_opt_state
-
-    # Apply the Constraint in the model to ensure apply-constrains are met
-    # initially
-    model = apply(model)
-
-    # Use the unflatten trick to speed up training,
-    # see https://docs.kidger.site/equinox/tricks/
-    flat_model, treedef_model = jax.tree.flatten(model)
-    flat_opt_state, treedef_opt_state = jax.tree.flatten(opt_state)
-    # Make callbacks iterable
-    callbacks = [] if callbacks is None else list(callbacks)
-
-    # Initialize callback arguments and history
-    if history is None:
-        history = HistoryCallback(log_every=100)
-    callbacks.append(history)
-
-    cbargs = CallbackArgs(
-        combined_loss, treedef_model, treedef_opt_state, data, validation_data
-    )
+    if callbacks is None:
+        callbacks = []
+    elif isinstance(callbacks, Callback):
+        callbacks = [callbacks]
+    else:
+        callbacks = list(callbacks)

-    # Call callbacks after training
-    cbargs.update(flat_model, flat_opt_state, 0)
+    # Initialize context and callbacks
+    ctx.update(ctx.flat_model, ctx.flat_opt_state)
     for callback in callbacks:
-        callback.on_training_start(cbargs)
+        callback.on_training_start(ctx)
+
+    for _ in range(steps):
+        next(batcher)  # Prime the batcher
+        batch = batcher.send(ctx)  # Send the context

-    # Loop over all training steps
-    for step, batch in zip(
-        range(1, steps + 1),
-        batcher(data, batch_size, batch_axis, key=key),
-    ):
         flat_model, flat_opt_state = make_step(
-            batch, flat_model, optimizer, flat_opt_state
+            batch, ctx.flat_model, ctx.flat_opt_state
         )

-        # Update callbacks arguments with the current state of the model
-        cbargs.update(flat_model, flat_opt_state, step)
+        ctx.update(flat_model, flat_opt_state)

         # Run all callbacks and break if any of them request termination of
         # the training loop.
         # Note! The square brackets are important. Otherwise the loop is
         # terminated with the first callback that returns true. But we want
         # to run all callbacks first and then decide whether to terminate.
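         # (Illustration, not part of this change set: ``any([f(ctx) for f in fs])``
         # builds the full list first, so every callback runs on every step,
         # whereas ``any(f(ctx) for f in fs)`` would stop at the first truthy
         # return value.)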
- if any([callback(cbargs) for callback in callbacks]): + if any([callback(ctx) for callback in callbacks]): break - model = jax.tree_util.tree_unflatten(treedef_model, flat_model) - # Call callbacks after training - cbargs.update(flat_model, flat_opt_state, -1) + ctx.update(ctx.flat_model, ctx.flat_opt_state) for callback in callbacks: - callback.on_training_end(cbargs) + callback.on_training_end(ctx) + + return ctx, list(callbacks) + + +def fit[T: eqx.Module]( + model: T, + data, + *, + batch_size: int = 32, + batch_axis: PyTree[int | None] = 0, + validation_data: PyTree[Any] = None, + steps: int = 1_000, + loss: LossFactory = mse, + optimizer: optax.GradientTransformation, + init_opt_state: PyTree[Any] = None, + batcher: Batcher = batch_data, + updater: UpdaterFactory = optax_transform_update_fn_updater, + callbacks: Iterable[Callback] | Callback | None = HistoryCallback(), + key: PRNGKeyArray, +) -> tuple[T, list[Callback]]: + value_fn, value_and_grad_fn = loss(batch_axis) + evaluator = EvaluationContext(value_fn, data, val_data=validation_data) + state = TrainingState( + model=model, + opt_state=optimizer.init(eqx.filter(model, eqx.is_inexact_array)) + if init_opt_state is None + else init_opt_state + if init_opt_state is None + else init_opt_state, + ) + ctx = TrainingContext( + state=state, + evaluator=evaluator, + timing=TimingInfo(), + ) + + ctx, callbacks = fit_core( + updater(optimizer.update, value_fn, value_and_grad_fn), + batcher( + data=data, batch_axis=batch_axis, batch_size=batch_size, key=key + ), + ctx, + steps, + callbacks=callbacks, + ) + return ctx.model, callbacks + + +if __name__ == "__main__": + # Test fit + x = jnp.linspace(0.0, 1.0, 2).reshape(-1, 1) + y = 2.0 * x + 1.0 + model = eqx.nn.Linear(1, 1, key=eqx.internal.GetKey()()) + model, _ = fit( + model, (x, y), optimizer=optax.adam(1.0), key=eqx.internal.GetKey()() + ) + y_pred = jax.vmap(model)(x) + assert jnp.allclose(y_pred, y) + + # Test fit_core + x = jnp.linspace(0.0, 1.0, 2).reshape(-1, 1) + y = 2.0 * x + 1.0 + data = (x, y) + model = eqx.nn.Linear(1, 1, key=eqx.internal.GetKey()()) + batch_axis = 0 + optimizer = optax.adam(1.0) + value_fn, value_and_grad_fn = mse(batch_axis) + evaluator = EvaluationContext(value_fn, (x, y)) + state = TrainingState( + model=model, + opt_state=optimizer.init(eqx.filter(model, eqx.is_inexact_array)), + ) + ctx = TrainingContext( + state=state, evaluator=evaluator, timing=TimingInfo() + ) + batcher = batch_data( + (x, y), + batch_size=32, + batch_axis=batch_axis, + key=eqx.internal.GetKey()(), # Unused + ) + updater = optax_transform_update_fn_updater( + optimizer.update, value_fn, value_and_grad_fn + ) + ctx, _ = fit_core(updater, batcher, ctx, steps=1000) + y_pred = jax.vmap(ctx.model)(x) + assert jnp.allclose(y_pred, y) + + import pprint - return model, history + pprint.pp(state) diff --git a/klax/_updaters.py b/klax/_updaters.py new file mode 100644 index 0000000..475fc41 --- /dev/null +++ b/klax/_updaters.py @@ -0,0 +1,67 @@ +from abc import abstractmethod +from typing import Protocol + +import equinox as eqx +import jax +import optax +from jaxtyping import PyTree + +from ._losses import ValueAndGradFn, ValueFn + + +class Updater(Protocol): + @abstractmethod + def __call__( + self, model: PyTree, batch: PyTree, opt_state: PyTree + ) -> tuple[PyTree, PyTree]: + pass + + +class UpdaterFactory(Protocol): + @abstractmethod + def __call__( + self, + opt_update: optax.TransformUpdateFn | optax.TransformUpdateExtraArgsFn, + value_fn: ValueFn, + value_and_grad_fn: 
ValueAndGradFn,
+    ) -> Updater:
+        pass
+
+
+def optax_transform_update_fn_updater(
+    opt_update: optax.TransformUpdateFn,
+    value_fn: ValueFn,
+    value_and_grad_fn: ValueAndGradFn,
+) -> Updater:
+    def wrapper(model, batch, opt_state):
+        _, grad = value_and_grad_fn(model, batch)
+        updates, opt_state = opt_update(
+            grad,
+            opt_state,
+            eqx.filter(model, eqx.is_inexact_array),
+        )
+        model = eqx.apply_updates(model, updates)
+        return model, opt_state
+
+    return wrapper
+
+
+def optax_transform_update_fn_extra_args_updater(
+    opt_update: optax.TransformUpdateExtraArgsFn,
+    value_fn: ValueFn,
+    value_and_grad_fn: ValueAndGradFn,
+) -> Updater:
+    def wrapper(model, batch, opt_state):
+        value, grad = value_and_grad_fn(model, batch)
+        # Optimizers such as optax.lbfgs re-evaluate the loss on candidate
+        # parameters, so expose a loss over the trainable arrays only and
+        # recombine them with the static part of the model (this mirrors the
+        # partitioned loss previously used inside fit).
+        params, static = eqx.partition(model, eqx.is_inexact_array)
+
+        def partitioned_value_fn(p):
+            return value_fn(eqx.combine(p, static), batch)
+
+        updates, opt_state = opt_update(
+            grad,
+            opt_state,
+            params,
+            value=value,
+            grad=grad,
+            value_fn=jax.tree_util.Partial(partitioned_value_fn),
+        )
+        model = eqx.apply_updates(model, updates)
+        return model, opt_state
+
+    return wrapper
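Usage sketch (illustrative, not part of this change set): how the refactored pieces fit together, based on the new `fit` signature, the `LossFactory` returned by `mse`, and the `TrainingContext` passed to callbacks. The model size, learning rate, and stopping threshold are placeholders, and the example assumes the existing package-level re-exports (`klax.mse`, `klax.HistoryCallback`, `klax.Callback`, `klax.TrainingContext`) remain in place.

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

import klax


class StopBelow(klax.Callback):
    """Hypothetical callback: stop once the training loss drops below tol."""

    def __init__(self, tol: float = 1e-4):
        self.tol = tol

    def on_training_start(self, ctx: klax.TrainingContext) -> None:
        pass

    def __call__(self, ctx: klax.TrainingContext) -> bool | None:
        # ctx.loss is evaluated lazily and cached per step by EvaluationContext.
        return bool(ctx.loss < self.tol)

    def on_training_end(self, ctx: klax.TrainingContext) -> None:
        print(f"Finished at step {ctx.step} with loss {ctx.loss:.3e}")


x = jnp.linspace(0.0, 1.0, 32).reshape(-1, 1)
y = 2.0 * x + 1.0

model = eqx.nn.MLP(1, 1, width_size=16, depth=2, key=jax.random.key(0))
model, callbacks = klax.fit(
    model,
    (x, y),
    loss=klax.mse,               # LossFactory -> (value_fn, value_and_grad_fn)
    optimizer=optax.adam(1e-3),  # required: fit no longer defaults the optimizer
    callbacks=[klax.HistoryCallback(log_every=100), StopBelow(1e-4)],
    key=jax.random.key(1),
)
```

Note that the new `Callback` base class marks `__call__`, `on_training_start`, and `on_training_end` as abstract, so a custom callback has to implement all three.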