Skip to content

Commit 08a8934

Browse files
committed
Change doc strings to match jax
Signed-off-by: Dashiell Stander <dstander@protonmail.com>
1 parent 4c96e5b commit 08a8934

1 file changed

Lines changed: 11 additions & 11 deletions

File tree

src/algebraist/fourier.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,12 @@ def restrict_to_coset(tensor: jax.Array, sn_perms: jax.Array, idx: int) -> jax.A
6868
There are n cosets of S_{n-1} < S_n. Young's Orthogonal Form (YOR) is specifically adapted to the copy of S_{n-1} where the element n is fixed in the nth position. The _cosets_ of this subgroup correspond to the elements that all have n in a given position.
6969
7070
Args:
71-
tensor (torch.Tensor): The function on S_n we are working with, either shape (batch, n!) or (n!,)
72-
sn_perms (torch.Tensor): A tensor-version of S_n with shape (n!, n), each row is the elements 0..n-1 permuted, and the rows are in lexicographic order
71+
tensor (jax.Array): The function on S_n we are working with, either shape (batch, n!) or (n!,)
72+
sn_perms (jax.Array): A tensor-version of S_n with shape (n!, n), each row is the elements 0..n-1 permuted, and the rows are in lexicographic order
7373
idx (int): The index of n that defines the coset we are grabbing
7474
7575
Returns:
76-
torch.Tensor either of shape (batch, (n-1)!) or ((n-1)!, ), depending on whether or not tensor had a batch dimension
76+
jax.Array either of shape (batch, (n-1)!) or ((n-1)!, ), depending on whether or not tensor had a batch dimension
7777
"""
7878
n = sn_perms.shape[1]
7979
fixed_element = n - 1
@@ -88,11 +88,11 @@ def _fourier_projection(fn_vals: jax.Array, irrep: SnIrrep):
8888
number of group elements is small enough that it is easier to rely on the inherent parallelism of PyTorch.
8989
9090
Args:
91-
fn_vals (torch.Tensor): Input tensor of shape (batch_size, n!) or (n!,)
91+
fn_vals (jax.Array): Input tensor of shape (batch_size, n!) or (n!,)
9292
irrep (SnIrrep): an irreducible representation of Sn
9393
9494
Returns:
95-
torch.Tensor: the projection of `fn_vals` onto the irreducible representation given by `irrep`
95+
jax.Array: the projection of `fn_vals` onto the irreducible representation given by `irrep`
9696
"""
9797

9898
matrices = irrep.matrix_tensor()
@@ -112,7 +112,7 @@ def _inverse_fourier_projection(ft: jax.Array, irrep: SnIrrep):
112112
number of group elements is small enough that it is easier to rely on the inherent parallelism of PyTorch.
113113
114114
Args:
115-
ft (torch.Tensor): Input tensor of shape (batch_size, irrep_dim, irrep_dim) or (irrep_dim, irrep_dim)
115+
ft (jax.Array): Input tensor of shape (batch_size, irrep_dim, irrep_dim) or (irrep_dim, irrep_dim)
116116
irrep (SnIrrep): an irreducible representation of Sn
117117
118118
Returns:
@@ -215,11 +215,11 @@ def fourier_projection(fn_vals: jax.Array, irrep: SnIrrep) -> jax.Array:
215215
To get around this we: (1) Use vmap across the given batch dimension of fn_vals (2) Always return the tensor to have the same batch dimension (or none) as fn_vals
216216
217217
Args:
218-
fn_vals (torch.Tensor): A tensor of shape (batch_size, n!) for an integer n
218+
fn_vals (jax.Array): A tensor of shape (batch_size, n!) for an integer n
219219
irrep (SnIrrep): An irreducible representation of Sn, given by an integer partition of n
220220
221221
Returns:
222-
torch.Tensor the projection of `fn_vals` onto `irrep` with shape (batch_size, irrep.dim, irrep.dim)
222+
jax.Array the projection of `fn_vals` onto `irrep` with shape (batch_size, irrep.dim, irrep.dim)
223223
"""
224224
n = irrep.n
225225
if n <= BASE_CASE or irrep.dim == 1:
@@ -286,7 +286,7 @@ def slow_sn_ft(fn_vals: jax.Array, n: int):
286286
Compute the Fourier transform on Sn.
287287
288288
Args:
289-
fn_vals (torch.Tensor): Input tensor of shape (batch_size, n!) or (n!,)
289+
fn_vals (jax.Array): Input tensor of shape (batch_size, n!) or (n!,)
290290
n (int): The order of the symmetric group
291291
292292
Returns:
@@ -319,7 +319,7 @@ def slow_sn_ift(ft, n: int):
319319
n (int): The order of the symmetric group
320320
321321
Returns:
322-
torch.Tensor: The inverse Fourier transform of shape (batch_size, n!)
322+
jax.Array: The inverse Fourier transform of shape (batch_size, n!)
323323
"""
324324
permutations = Permutation.full_group(n)
325325
group_order = len(permutations)
@@ -347,7 +347,7 @@ def slow_sn_fourier_decomposition(ft, n: int):
347347
n (int): The order of the symmetric group
348348
349349
Returns:
350-
torch.Tensor: The inverse Fourier transform of shape (batch_size, n!)
350+
jax.Array: The inverse Fourier transform of shape (batch_size, n!)
351351
"""
352352
permutations = Permutation.full_group(n)
353353
group_order = len(permutations)

0 commit comments

Comments
 (0)