diff --git a/jraph/_src/models.py b/jraph/_src/models.py index e6a2562..dd31752 100644 --- a/jraph/_src/models.py +++ b/jraph/_src/models.py @@ -547,19 +547,26 @@ def _ApplyGCN(graph): nodes = update_node_fn(nodes) # Equivalent to jnp.sum(n_node), but jittable total_num_nodes = tree.tree_leaves(nodes)[0].shape[0] + + # Handle None senders and receivers by initializing empty arrays + if senders is None: + senders = jnp.array([], dtype=jnp.int32) + if receivers is None: + receivers = jnp.array([], dtype=jnp.int32) + if add_self_edges: - # We add self edges to the senders and receivers so that each node - # includes itself in aggregation. - # In principle, a `GraphsTuple` should partition by n_edge, but in - # this case it is not required since a GCN is agnostic to whether - # the `GraphsTuple` is a batch of graphs or a single large graph. - conv_receivers = jnp.concatenate((receivers, jnp.arange(total_num_nodes)), + # We add self edges to the senders and receivers so that each node + # includes itself in aggregation. + # In principle, a `GraphsTuple` should partition by n_edge, but in + # this case it is not required since a GCN is agnostic to whether + # the `GraphsTuple` is a batch of graphs or a single large graph. + conv_receivers = jnp.concatenate((receivers, jnp.arange(total_num_nodes)), axis=0) - conv_senders = jnp.concatenate((senders, jnp.arange(total_num_nodes)), + conv_senders = jnp.concatenate((senders, jnp.arange(total_num_nodes)), axis=0) else: - conv_senders = senders - conv_receivers = receivers + conv_senders = senders + conv_receivers = receivers # pylint: disable=g-long-lambda if symmetric_normalization: