From a0856d8487c65e549ca6bcf12a4cd40605804c55 Mon Sep 17 00:00:00 2001 From: Giovanni Trezza Date: Wed, 4 Mar 2026 17:33:20 +0100 Subject: [PATCH] Fix metrics in explicit batching --- src/tensorial/gcnn/atomic/_metrics.py | 26 ++++++++++++++++++++------ src/tensorial/gcnn/graph_ops.py | 2 +- src/tensorial/gcnn/metrics.py | 20 ++++++++++++++++++-- 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/src/tensorial/gcnn/atomic/_metrics.py b/src/tensorial/gcnn/atomic/_metrics.py index 75c592e..c2191b6 100644 --- a/src/tensorial/gcnn/atomic/_metrics.py +++ b/src/tensorial/gcnn/atomic/_metrics.py @@ -53,12 +53,26 @@ def get(mapping: Mapping, key: str): ) -AvgNumNeighbours = reax.metrics.Average.from_fun( - lambda graph, *_: ( - jnp.bincount(graph.senders, length=jnp.sum(graph.n_node)), - graph.nodes.get(graph_keys.MASK), - ) -) +def _bincount_neighbours(graph: jraph.GraphsTuple, *_): + """Helper to count neighbours for both implicit and explicit batching.""" + node_mask = graph.nodes.get(graph_keys.MASK) + if graph.senders.ndim == 1: + # Implicit batching (standard jraph) + # Use shape which is static even for tracers if available + num_nodes = node_mask.shape[0] if node_mask is not None else jnp.sum(graph.n_node) + counts = jnp.bincount(graph.senders, length=num_nodes) + else: + # Explicit batching (stacked graphs) + # We use .shape[1] which is a static Python int for stacked arrays. + node_array = node_mask if node_mask is not None else jax.tree.leaves(graph.nodes)[0] + num_nodes = node_array.shape[1] + counts = jax.vmap(lambda s: jnp.bincount(s, length=num_nodes))(graph.senders) + + # Flatten to make it sample-wise for all nodes across the batch + return counts.reshape(-1), node_mask.reshape(-1) if node_mask is not None else None + + +AvgNumNeighbours = reax.metrics.Average.from_fun(_bincount_neighbours) class EnergyPerAtomLstsq(reax.metrics.FromFun): diff --git a/src/tensorial/gcnn/graph_ops.py b/src/tensorial/gcnn/graph_ops.py index 73817e9..7b2f013 100644 --- a/src/tensorial/gcnn/graph_ops.py +++ b/src/tensorial/gcnn/graph_ops.py @@ -294,7 +294,7 @@ def segment_reduce( Returns: The reduced array. Shape (num_segments, D) or (num_segments,). """ - # 3Handle Reduction Type + # Handle Reduction Type try: fn = _REDUCTIONS[reduction] return fn(data, segment_sizes, mask=mask, segment_mask=segment_mask) diff --git a/src/tensorial/gcnn/metrics.py b/src/tensorial/gcnn/metrics.py index 691f95d..cbe1186 100644 --- a/src/tensorial/gcnn/metrics.py +++ b/src/tensorial/gcnn/metrics.py @@ -242,8 +242,24 @@ def _calc_averages(self, graphs: jraph.GraphsTuple, *_) -> Averages: # Transform the type numbers from whatever they are to 0, 1, 2.... types = nn_utils.vwhere(types, self._node_types) - counts = jnp.bincount(graphs.senders, length=jnp.sum(graphs.n_node).item()) - mask = reax.metrics.utils.prepare_mask(counts, graphs.nodes.get(keys.MASK)) + mask = graphs.nodes.get(keys.MASK) + + if graphs.senders.ndim == 1: + # Implicit batching + num_nodes = mask.shape[0] if mask is not None else jnp.sum(graphs.n_node) + counts = jnp.bincount(graphs.senders, length=num_nodes) + else: + # Explicit batching + node_array = mask if mask is not None else jax.tree.leaves(graphs.nodes)[0] + num_nodes = node_array.shape[1] + counts = jax.vmap(lambda s: jnp.bincount(s, length=num_nodes))(graphs.senders) + + # Flatten everything to be consistent for reax.metrics.Average + counts = counts.reshape(-1) + mask = mask.reshape(-1) if mask is not None else None + types = types.reshape(-1) + + mask = reax.metrics.utils.prepare_mask(counts, mask) mask = mask if mask is not None else True num_classes = len(self._node_types)