Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions src/tensorial/gcnn/atomic/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,26 @@ def get(mapping: Mapping, key: str):
)


AvgNumNeighbours = reax.metrics.Average.from_fun(
lambda graph, *_: (
jnp.bincount(graph.senders, length=jnp.sum(graph.n_node)),
graph.nodes.get(graph_keys.MASK),
)
)
def _bincount_neighbours(graph: jraph.GraphsTuple, *_):
"""Helper to count neighbours for both implicit and explicit batching."""
node_mask = graph.nodes.get(graph_keys.MASK)
if graph.senders.ndim == 1:
# Implicit batching (standard jraph)
# Use shape which is static even for tracers if available
num_nodes = node_mask.shape[0] if node_mask is not None else jnp.sum(graph.n_node)
counts = jnp.bincount(graph.senders, length=num_nodes)
else:
# Explicit batching (stacked graphs)
# We use .shape[1] which is a static Python int for stacked arrays.
node_array = node_mask if node_mask is not None else jax.tree.leaves(graph.nodes)[0]
num_nodes = node_array.shape[1]
counts = jax.vmap(lambda s: jnp.bincount(s, length=num_nodes))(graph.senders)

# Flatten to make it sample-wise for all nodes across the batch
return counts.reshape(-1), node_mask.reshape(-1) if node_mask is not None else None


AvgNumNeighbours = reax.metrics.Average.from_fun(_bincount_neighbours)


class EnergyPerAtomLstsq(reax.metrics.FromFun):
Expand Down
2 changes: 1 addition & 1 deletion src/tensorial/gcnn/graph_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def segment_reduce(
Returns:
The reduced array. Shape (num_segments, D) or (num_segments,).
"""
# 3Handle Reduction Type
# Handle Reduction Type
try:
fn = _REDUCTIONS[reduction]
return fn(data, segment_sizes, mask=mask, segment_mask=segment_mask)
Expand Down
20 changes: 18 additions & 2 deletions src/tensorial/gcnn/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,24 @@ def _calc_averages(self, graphs: jraph.GraphsTuple, *_) -> Averages:
# Transform the type numbers from whatever they are to 0, 1, 2....
types = nn_utils.vwhere(types, self._node_types)

counts = jnp.bincount(graphs.senders, length=jnp.sum(graphs.n_node).item())
mask = reax.metrics.utils.prepare_mask(counts, graphs.nodes.get(keys.MASK))
mask = graphs.nodes.get(keys.MASK)

if graphs.senders.ndim == 1:
# Implicit batching
num_nodes = mask.shape[0] if mask is not None else jnp.sum(graphs.n_node)
counts = jnp.bincount(graphs.senders, length=num_nodes)
else:
# Explicit batching
node_array = mask if mask is not None else jax.tree.leaves(graphs.nodes)[0]
num_nodes = node_array.shape[1]
counts = jax.vmap(lambda s: jnp.bincount(s, length=num_nodes))(graphs.senders)

# Flatten everything to be consistent for reax.metrics.Average
counts = counts.reshape(-1)
mask = mask.reshape(-1) if mask is not None else None
types = types.reshape(-1)

mask = reax.metrics.utils.prepare_mask(counts, mask)
mask = mask if mask is not None else True

num_classes = len(self._node_types)
Expand Down
Loading