diff --git a/docs/api/tools.md b/docs/api/tools.md new file mode 100644 index 0000000..5906413 --- /dev/null +++ b/docs/api/tools.md @@ -0,0 +1,7 @@ +--- +title: Model tools +--- + +## Extracting model information + +::: klax.count_parameters \ No newline at end of file diff --git a/klax/__init__.py b/klax/__init__.py index c0974a9..77457e2 100644 --- a/klax/__init__.py +++ b/klax/__init__.py @@ -35,6 +35,7 @@ from ._serialization import ( text_serialize_filter_spec as text_serialize_filter_spec, ) +from ._tools import count_parameters as count_parameters from ._training import fit as fit from ._wrappers import Constraint as Constraint from ._wrappers import NonNegative as NonNegative diff --git a/klax/_tools.py b/klax/_tools.py new file mode 100644 index 0000000..e6e6957 --- /dev/null +++ b/klax/_tools.py @@ -0,0 +1,41 @@ +import operator + +import equinox as eqx +import jax +from jaxtyping import PyTree + +from ._wrappers import NonTrainable + + +def count_parameters(model: PyTree) -> int: + """Count the number of trainable parameters in a model. + + Under the hood this just counts the number of inexact + JAX/NumPy array elements in the pytree that are not + wrapped by [`klax.NonTrainable`][]. + + Warning: + If you use `jax.lax.stop_gradient` or any other method + besides [`klax.NonTrainable`][] to make arrays not receive + gradient updates, then this function will overestimate + the number of trainable parameters! + Consider using [`klax.NonTrainable`][] or counting the trainable + parameters manually. + + Args: + model: Arbitrary pytree. + + Returns: + Integer count of inexact JAX/NumPy array elements. 
+ + """ + return jax.tree.reduce( + operator.add, + jax.tree.map( + lambda x: x.size + if eqx.is_inexact_array(x) and not isinstance(x, NonTrainable) + else None, + model, + is_leaf=lambda x: isinstance(x, NonTrainable), + ), + 0) diff --git a/mkdocs.yml b/mkdocs.yml index 4b355a3..e1c0a6a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -104,6 +104,7 @@ nav: - api/losses.md - api/callbacks.md - api/serialization.md + - api/tools.md - api/initialization.md - Neural Networks: - api/nn/linear.md diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 0000000..b0d5702 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,32 @@ +import equinox as eqx +import numpy as np +from jax import numpy as jnp +from jaxtyping import Array + +from klax import NonTrainable, count_parameters + + +class TestParameterCount: + def test_pytree_with_jax_leaves(self): + tree = (jnp.ones((3, 3)), jnp.zeros((2, 3))) + assert count_parameters(tree) == 15 + + def test_pytree_with_numpy_leaves(self): + tree = (np.ones((3, 3)), np.zeros((2, 3))) + assert count_parameters(tree) == 15 + + def test_on_equinox_module(self): + class DummyModel(eqx.Module): + no_parameter: int + weight: tuple[Array, Array] + non_trainable: NonTrainable + + def __init__( + self, + ): + self.no_parameter = 1 + self.weight = (jnp.ones((3, 3)), jnp.zeros((3, 3))) + self.non_trainable = NonTrainable(jnp.ones((2, 2))) + + dummy_model = DummyModel() + assert count_parameters(dummy_model) == 18