2525
2626
2727BASE_CASE = 5
28+ COSET_INDICES_CACHE = {}
29+
30+
31+ def get_coset_indices (n , idx ):
32+ """
33+ Precompute indices for a specific coset.
34+
35+ Args:
36+ n: Size of the permutation group
37+ idx: Position index for the fixed element
38+
39+ Returns:
40+ Array of indices where element (n-1) is at position idx
41+ """
42+ # This function is only called outside of JIT contexts
43+ sn_perms = generate_all_permutations (n )
44+ fixed_element = n - 1
45+ mask = (sn_perms [:, idx ] == fixed_element )
46+ return jnp .where (mask )[0 ] # Get the actual indices where mask is True
47+
48+
49+ def get_cached_coset_indices (n , idx ):
50+ """Get cached coset indices or compute them if not cached."""
51+ key = (n , idx )
52+ if key not in COSET_INDICES_CACHE :
53+ COSET_INDICES_CACHE [key ] = get_coset_indices (n , idx )
54+ return COSET_INDICES_CACHE [key ]
55+
2856
2957
3058def get_all_irreps (n : int ) -> list [SnIrrep ]:
@@ -56,18 +84,15 @@ def lift_from_coset(lifted_fn: jax.Array, coset_fn: jax.Array, n: int, idx: int)
5684 Returns:
5785 None, operates in place on lifted_fn
5886 """
59- sn_perms = generate_all_permutations (n )
60- fixed_element = n - 1
61-
6287 # Create a boolean mask instead of using argwhere
63- mask = ( sn_perms [: , idx ] == fixed_element )
88+ indices = get_cached_coset_indices ( n , idx )
6489
6590 # Use boolean indexing with .at[] syntax
6691 if coset_fn .ndim > 1 :
6792 # Handle batch dimension
68- lifted_fn = lifted_fn .at [:, mask ].set (coset_fn )
93+ lifted_fn = lifted_fn .at [:, indices ].set (coset_fn )
6994 else :
70- lifted_fn = lifted_fn .at [mask ].set (coset_fn )
95+ lifted_fn = lifted_fn .at [indices ].set (coset_fn )
7196
7297 return lifted_fn
7398
@@ -85,18 +110,15 @@ def restrict_to_coset(tensor: jax.Array, n: int, idx: int) -> jax.Array:
85110 Returns:
86111 jax.Array either of shape (batch, (n-1)!) or ((n-1)!, ), depending on whether tensor had a batch dimension
87112 """
88- sn_perms = generate_all_permutations (n )
89- fixed_element = n - 1
90113
91- # Create a boolean mask instead of using argwhere
92- mask = (sn_perms [:, idx ] == fixed_element )
114+ indices = get_cached_coset_indices (n , idx )
93115
94116 # Use boolean indexing which works with JIT
95117 if tensor .ndim > 1 :
96118 # Handle batch dimension
97- return tensor [:, mask ]
119+ return tensor [:, indices ]
98120 else :
99- return tensor [mask ]
121+ return tensor [indices ]
100122
101123
102124@partial (jax .jit , static_argnums = 1 )
0 commit comments