Skip to content

Commit 0d048d3

Browse files
committed
JIT friendly indexing
Signed-off-by: Dashiell Stander <dstander@protonmail.com>
1 parent 13a8e99 commit 0d048d3

1 file changed

Lines changed: 34 additions & 12 deletions

File tree

src/algebraist/fourier.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,34 @@
2525

2626

2727
BASE_CASE = 5
28+
COSET_INDICES_CACHE = {}
29+
30+
31+
def get_coset_indices(n, idx):
32+
"""
33+
Precompute indices for a specific coset.
34+
35+
Args:
36+
n: Size of the permutation group
37+
idx: Position index for the fixed element
38+
39+
Returns:
40+
Array of indices where element (n-1) is at position idx
41+
"""
42+
# This function is only called outside of JIT contexts
43+
sn_perms = generate_all_permutations(n)
44+
fixed_element = n - 1
45+
mask = (sn_perms[:, idx] == fixed_element)
46+
return jnp.where(mask)[0] # Get the actual indices where mask is True
47+
48+
49+
def get_cached_coset_indices(n, idx):
50+
"""Get cached coset indices or compute them if not cached."""
51+
key = (n, idx)
52+
if key not in COSET_INDICES_CACHE:
53+
COSET_INDICES_CACHE[key] = get_coset_indices(n, idx)
54+
return COSET_INDICES_CACHE[key]
55+
2856

2957

3058
def get_all_irreps(n: int) -> list[SnIrrep]:
@@ -56,18 +84,15 @@ def lift_from_coset(lifted_fn: jax.Array, coset_fn: jax.Array, n: int, idx: int)
5684
Returns:
5785
None, operates in place on lifted_fn
5886
"""
59-
sn_perms = generate_all_permutations(n)
60-
fixed_element = n - 1
61-
6287
# Create a boolean mask instead of using argwhere
63-
mask = (sn_perms[:, idx] == fixed_element)
88+
indices = get_cached_coset_indices(n, idx)
6489

6590
# Use boolean indexing with .at[] syntax
6691
if coset_fn.ndim > 1:
6792
# Handle batch dimension
68-
lifted_fn = lifted_fn.at[:, mask].set(coset_fn)
93+
lifted_fn = lifted_fn.at[:, indices].set(coset_fn)
6994
else:
70-
lifted_fn = lifted_fn.at[mask].set(coset_fn)
95+
lifted_fn = lifted_fn.at[indices].set(coset_fn)
7196

7297
return lifted_fn
7398

@@ -85,18 +110,15 @@ def restrict_to_coset(tensor: jax.Array, n: int, idx: int) -> jax.Array:
85110
Returns:
86111
jax.Array either of shape (batch, (n-1)!) or ((n-1)!, ), depending on whether tensor had a batch dimension
87112
"""
88-
sn_perms = generate_all_permutations(n)
89-
fixed_element = n - 1
90113

91-
# Create a boolean mask instead of using argwhere
92-
mask = (sn_perms[:, idx] == fixed_element)
114+
indices = get_cached_coset_indices(n, idx)
93115

94116
# Use boolean indexing which works with JIT
95117
if tensor.ndim > 1:
96118
# Handle batch dimension
97-
return tensor[:, mask]
119+
return tensor[:, indices]
98120
else:
99-
return tensor[mask]
121+
return tensor[indices]
100122

101123

102124
@partial(jax.jit, static_argnums=1)

0 commit comments

Comments
 (0)