Skip to content

Commit 297b312

Browse files
authored
sp fix ci test hang (#52)
* feat(tests): replace manual sp_group retrieval with module attribute

  Replace calls to `_get_sp_group_from_device_mesh` with direct access to `sequence_parallel._sp_group` in sequence parallel attention tests. This simplifies the test setup by using the already initialized group stored in the module, improving code clarity and reducing redundancy.

* feat(tests): remove unused import in sequence parallel test

  Remove the `_get_sp_group_from_device_mesh` import from the test file, as it is no longer used in the test, cleaning up imports and improving code clarity.
1 parent 9d3402b commit 297b312

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

tests/sequence_parallel/test_sequence_parallel_single_attention.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
from twinkle.model.transformers.strategy.sequence_parallel import (
2121
DistributedAttention,
22-
_get_sp_group_from_device_mesh,
2322
sequence_parallel,
2423
)
2524
from twinkle.model.transformers.strategy import NativeFSDPStrategy
@@ -181,7 +180,7 @@ def _run_worker_single_attn(rank: int, world_size: int, port: int, padding: bool
181180
sp_size = world_size
182181
device_mesh = DeviceMesh.from_sizes(dp_size=world_size, ulysses_size=sp_size, device_type="cuda")
183182
_setup_sp(device_mesh, sp_size)
184-
sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size)
183+
sp_group = sequence_parallel._sp_group
185184

186185
batch_size = 2
187186
unpad_seq_len = 127 if padding else 128
@@ -271,7 +270,7 @@ def _run_worker_single_attn_fsdp(rank: int, world_size: int, port: int):
271270
# For FSDP+SP, SP is derived from dp/fsdp ranks. Use fsdp=world, dp=1.
272271
device_mesh = DeviceMesh.from_sizes(fsdp_size=world_size, dp_size=1, ulysses_size=sp_size, device_type="cuda")
273272
_setup_sp(device_mesh, sp_size)
274-
sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size)
273+
sp_group = sequence_parallel._sp_group
275274

276275
batch_size = 2
277276
unpad_seq_len = 128

0 commit comments

Comments (0)