From d125aba5d5b9895205eb089faf954288c9f9e4c7 Mon Sep 17 00:00:00 2001 From: Fabian Roth Date: Mon, 24 Nov 2025 17:10:07 +0100 Subject: [PATCH 1/4] Added `parameter_count` to estimate the number of trainable parameters ina a model. Added docs and tests. --- docs/api/tools.md | 7 +++++++ klax/__init__.py | 1 + klax/_tools.py | 40 ++++++++++++++++++++++++++++++++++++++++ mkdocs.yml | 2 ++ tests/test_tools.py | 25 +++++++++++++++++++++++++ 5 files changed, 75 insertions(+) create mode 100644 docs/api/tools.md create mode 100644 klax/_tools.py create mode 100644 tests/test_tools.py diff --git a/docs/api/tools.md b/docs/api/tools.md new file mode 100644 index 0000000..98be3a1 --- /dev/null +++ b/docs/api/tools.md @@ -0,0 +1,7 @@ +--- +title: Model tools +--- + +## Extracting model information + +::: klax.parameter_count \ No newline at end of file diff --git a/klax/__init__.py b/klax/__init__.py index 385282b..5c12c17 100644 --- a/klax/__init__.py +++ b/klax/__init__.py @@ -52,6 +52,7 @@ from ._serialization import ( text_serialize_filter_spec as text_serialize_filter_spec, ) +from ._tools import parameter_count from ._training import fit as fit from ._wrappers import ( Constraint as Constraint, diff --git a/klax/_tools.py b/klax/_tools.py new file mode 100644 index 0000000..1645c91 --- /dev/null +++ b/klax/_tools.py @@ -0,0 +1,40 @@ +import equinox as eqx +import jax +from jaxtyping import PyTree + +from ._wrappers import NonTrainable + + +def parameter_count(model: PyTree) -> int: + """Count the number of trainable parameters in a model. + + Under the hood this just counts the number of inexact + JAX/NumPy array elements in the pytree that are not + wrapped by [`klax.NonTrainable`][]. + + Warning: + If you use `jax.lax.stop_gradient` or any other method + besides [`klax.NonTrainable`][] to make arrays not receive + gradient updates, then this function will overestimate + the number of trainable parameters! + Consider using [`klax.NonTrainable`][] or counting the trainable + parameters manually. + + Args: + model: Arbitrary pytree. + + Returns: + Integer count of inexact inexact JAX/NumPy array elements. + + """ + return sum( + jax.tree.flatten( + jax.tree.map( + lambda x: x.size + if eqx.is_inexact_array(x) and not isinstance(x, NonTrainable) + else 0, + model, + is_leaf=lambda x: isinstance(x, NonTrainable), + ) + )[0] + ) diff --git a/mkdocs.yml b/mkdocs.yml index d78f75c..58227c3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -103,6 +103,8 @@ nav: - api/losses.md - api/callbacks.md - api/serialization.md + - api/initialization.md + - api/tools.md - Neural Networks: - api/nn/linear.md - api/nn/mlp.md diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 0000000..faa1413 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,25 @@ +import equinox as eqx +from jax import numpy as jnp +from jaxtyping import Array + +from klax import NonTrainable, parameter_count + + +def test_parameter_count(): + tree = (jnp.ones((3, 3)), jnp.zeros((2, 3))) + assert parameter_count(tree) == 15 + + class DummyModel(eqx.Module): + no_parameter: int + weight: tuple[Array, Array] + non_trainable: NonTrainable + + def __init__( + self, + ): + self.no_parameter = 1 + self.weight = (jnp.ones((3, 3)), jnp.zeros((3, 3))) + self.non_trainable = NonTrainable(jnp.ones((2, 2))) + + dummy_model = DummyModel() + assert parameter_count(dummy_model) == 18 From 9410728f9b93183f34d0c3786b05a22d6db1e129 Mon Sep 17 00:00:00 2001 From: Fabian Roth Date: Mon, 24 Nov 2025 17:27:15 +0100 Subject: [PATCH 2/4] Removed erroneous docs api reference. --- mkdocs.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/mkdocs.yml b/mkdocs.yml index 58227c3..48ec707 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -103,7 +103,6 @@ nav: - api/losses.md - api/callbacks.md - api/serialization.md - - api/initialization.md - api/tools.md - Neural Networks: - api/nn/linear.md From 4ac2663082efb77e68fadf5e59affe3414427388 Mon Sep 17 00:00:00 2001 From: Fabian Roth Date: Tue, 27 Jan 2026 11:29:46 +0100 Subject: [PATCH 3/4] Slight refactor of the parameter count function and tests. --- klax/_tools.py | 21 +++++++++++---------- tests/test_tools.py | 32 +++++++++++++++++--------------- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/klax/_tools.py b/klax/_tools.py index 1645c91..0e65a1e 100644 --- a/klax/_tools.py +++ b/klax/_tools.py @@ -1,3 +1,5 @@ +import operator + import equinox as eqx import jax from jaxtyping import PyTree @@ -27,14 +29,13 @@ def parameter_count(model: PyTree) -> int: Integer count of inexact inexact JAX/NumPy array elements. """ - return sum( - jax.tree.flatten( - jax.tree.map( - lambda x: x.size - if eqx.is_inexact_array(x) and not isinstance(x, NonTrainable) - else 0, - model, - is_leaf=lambda x: isinstance(x, NonTrainable), - ) - )[0] + return jax.tree.reduce( + operator.add, + jax.tree.map( + lambda x: x.size + if eqx.is_inexact_array(x) and not isinstance(x, NonTrainable) + else None, + model, + is_leaf=lambda x: isinstance(x, NonTrainable), + ), ) diff --git a/tests/test_tools.py b/tests/test_tools.py index faa1413..ffdaacb 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -5,21 +5,23 @@ from klax import NonTrainable, parameter_count -def test_parameter_count(): - tree = (jnp.ones((3, 3)), jnp.zeros((2, 3))) - assert parameter_count(tree) == 15 +class TestParameterCount: + def test_on_simple_pytree(self): + tree = (jnp.ones((3, 3)), jnp.zeros((2, 3))) + assert parameter_count(tree) == 15 - class DummyModel(eqx.Module): - no_parameter: int - weight: tuple[Array, Array] - non_trainable: NonTrainable + def test_on_equinox_module(self): + class DummyModel(eqx.Module): + no_parameter: int + weight: tuple[Array, Array] + non_trainable: NonTrainable - def __init__( - self, - ): - self.no_parameter = 1 - self.weight = (jnp.ones((3, 3)), jnp.zeros((3, 3))) - self.non_trainable = NonTrainable(jnp.ones((2, 2))) + def __init__( + self, + ): + self.no_parameter = 1 + self.weight = (jnp.ones((3, 3)), jnp.zeros((3, 3))) + self.non_trainable = NonTrainable(jnp.ones((2, 2))) - dummy_model = DummyModel() - assert parameter_count(dummy_model) == 18 + dummy_model = DummyModel() + assert parameter_count(dummy_model) == 18 From bf88fa0b9d96cf2e0c164ad38e50a218c0c75885 Mon Sep 17 00:00:00 2001 From: Fabian Roth Date: Tue, 27 Jan 2026 11:36:19 +0100 Subject: [PATCH 4/4] Renamed function to count_parameters. --- docs/api/tools.md | 2 +- klax/__init__.py | 2 +- klax/_tools.py | 2 +- tests/test_tools.py | 13 +++++++++---- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/docs/api/tools.md b/docs/api/tools.md index 98be3a1..5906413 100644 --- a/docs/api/tools.md +++ b/docs/api/tools.md @@ -4,4 +4,4 @@ title: Model tools ## Extracting model information -::: klax.parameter_count \ No newline at end of file +::: klax.count_parameters \ No newline at end of file diff --git a/klax/__init__.py b/klax/__init__.py index 715f30b..77457e2 100644 --- a/klax/__init__.py +++ b/klax/__init__.py @@ -35,7 +35,7 @@ from ._serialization import ( text_serialize_filter_spec as text_serialize_filter_spec, ) -from ._tools import parameter_count +from ._tools import count_parameters from ._training import fit as fit from ._wrappers import Constraint as Constraint from ._wrappers import NonNegative as NonNegative diff --git a/klax/_tools.py b/klax/_tools.py index 0e65a1e..e6e6957 100644 --- a/klax/_tools.py +++ b/klax/_tools.py @@ -7,7 +7,7 @@ from ._wrappers import NonTrainable -def parameter_count(model: PyTree) -> int: +def count_parameters(model: PyTree) -> int: """Count the number of trainable parameters in a model. Under the hood this just counts the number of inexact diff --git a/tests/test_tools.py b/tests/test_tools.py index ffdaacb..b0d5702 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,14 +1,19 @@ import equinox as eqx +import numpy as np from jax import numpy as jnp from jaxtyping import Array -from klax import NonTrainable, parameter_count +from klax import NonTrainable, count_parameters class TestParameterCount: - def test_on_simple_pytree(self): + def test_pytree_with_jax_leaves(self): tree = (jnp.ones((3, 3)), jnp.zeros((2, 3))) - assert parameter_count(tree) == 15 + assert count_parameters(tree) == 15 + + def test_pytree_with_numpy_leaves(self): + tree = (np.ones((3, 3)), np.zeros((2, 3))) + assert count_parameters(tree) == 15 def test_on_equinox_module(self): class DummyModel(eqx.Module): @@ -24,4 +29,4 @@ def __init__( self.non_trainable = NonTrainable(jnp.ones((2, 2))) dummy_model = DummyModel() - assert parameter_count(dummy_model) == 18 + assert count_parameters(dummy_model) == 18