Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/api/tools.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
title: Model tools
---

## Extracting model information

::: klax.count_parameters
1 change: 1 addition & 0 deletions klax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ._serialization import (
text_serialize_filter_spec as text_serialize_filter_spec,
)
from ._tools import count_parameters
from ._training import fit as fit
from ._wrappers import Constraint as Constraint
from ._wrappers import NonNegative as NonNegative
Expand Down
41 changes: 41 additions & 0 deletions klax/_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import operator

import equinox as eqx
import jax
from jaxtyping import PyTree

from ._wrappers import NonTrainable


def count_parameters(model: PyTree) -> int:
"""Count the number of trainable parameters in a model.

Under the hood this just counts the number of inexact
JAX/NumPy array elements in the pytree that are not
wrapped by [`klax.NonTrainable`][].

Warning:
If you use `jax.lax.stop_gradient` or any other method
besides [`klax.NonTrainable`][] to make arrays not receive
gradient updates, then this function will overestimate
the number of trainable parameters!
Consider using [`klax.NonTrainable`][] or counting the trainable
parameters manually.

Args:
model: Arbitrary pytree.

Returns:
Integer count of inexact inexact JAX/NumPy array elements.

"""
return jax.tree.reduce(
operator.add,
jax.tree.map(
lambda x: x.size
if eqx.is_inexact_array(x) and not isinstance(x, NonTrainable)
else None,
model,
is_leaf=lambda x: isinstance(x, NonTrainable),
),
)
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ nav:
- api/losses.md
- api/callbacks.md
- api/serialization.md
- api/tools.md
- api/initialization.md
- Neural Networks:
- api/nn/linear.md
Expand Down
32 changes: 32 additions & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import equinox as eqx
import numpy as np
from jax import numpy as jnp
from jaxtyping import Array

from klax import NonTrainable, count_parameters


class TestParameterCount:
def test_pytree_with_jax_leaves(self):
tree = (jnp.ones((3, 3)), jnp.zeros((2, 3)))
assert count_parameters(tree) == 15

def test_pytree_with_numpy_leaves(self):
tree = (np.ones((3, 3)), np.zeros((2, 3)))
assert count_parameters(tree) == 15

def test_on_equinox_module(self):
class DummyModel(eqx.Module):
no_parameter: int
weight: tuple[Array, Array]
non_trainable: NonTrainable

def __init__(
self,
):
self.no_parameter = 1
self.weight = (jnp.ones((3, 3)), jnp.zeros((3, 3)))
self.non_trainable = NonTrainable(jnp.ones((2, 2)))

dummy_model = DummyModel()
assert count_parameters(dummy_model) == 18
Loading