Skip to content

Commit 236a03e

Browse files
committed
feat(tests): replace manual sp_group retrieval with module attribute
- Use `sequence_parallel._sp_group` directly instead of calling `_get_sp_group_from_device_mesh` - Simplifies test setup by relying on internal module state after `_setup_sp`
1 parent be4b33d commit 236a03e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/sequence_parallel/test_sequence_parallel_single_attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def _run_worker_single_attn(rank: int, world_size: int, port: int, padding: bool
181181
sp_size = world_size
182182
device_mesh = DeviceMesh.from_sizes(dp_size=world_size, ulysses_size=sp_size, device_type="cuda")
183183
_setup_sp(device_mesh, sp_size)
184-
sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size)
184+
sp_group = sequence_parallel._sp_group
185185

186186
batch_size = 2
187187
unpad_seq_len = 127 if padding else 128
@@ -271,7 +271,7 @@ def _run_worker_single_attn_fsdp(rank: int, world_size: int, port: int):
271271
# For FSDP+SP, SP is derived from dp/fsdp ranks. Use fsdp=world, dp=1.
272272
device_mesh = DeviceMesh.from_sizes(fsdp_size=world_size, dp_size=1, ulysses_size=sp_size, device_type="cuda")
273273
_setup_sp(device_mesh, sp_size)
274-
sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size)
274+
sp_group = sequence_parallel._sp_group
275275

276276
batch_size = 2
277277
unpad_seq_len = 128

0 commit comments

Comments (0)