diff --git a/tests/sequence_parallel/test_sequence_parallel_single_attention.py b/tests/sequence_parallel/test_sequence_parallel_single_attention.py
index dde6b387..ef4c29fc 100644
--- a/tests/sequence_parallel/test_sequence_parallel_single_attention.py
+++ b/tests/sequence_parallel/test_sequence_parallel_single_attention.py
@@ -19,7 +19,6 @@
 
 from twinkle.model.transformers.strategy.sequence_parallel import (
     DistributedAttention,
-    _get_sp_group_from_device_mesh,
     sequence_parallel,
 )
 from twinkle.model.transformers.strategy import NativeFSDPStrategy
@@ -181,7 +180,7 @@ def _run_worker_single_attn(rank: int, world_size: int, port: int, padding: bool
     sp_size = world_size
     device_mesh = DeviceMesh.from_sizes(dp_size=world_size, ulysses_size=sp_size, device_type="cuda")
     _setup_sp(device_mesh, sp_size)
-    sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size)
+    sp_group = sequence_parallel._sp_group
 
     batch_size = 2
     unpad_seq_len = 127 if padding else 128
@@ -271,7 +270,7 @@ def _run_worker_single_attn_fsdp(rank: int, world_size: int, port: int):
     # For FSDP+SP, SP is derived from dp/fsdp ranks. Use fsdp=world, dp=1.
     device_mesh = DeviceMesh.from_sizes(fsdp_size=world_size, dp_size=1, ulysses_size=sp_size, device_type="cuda")
     _setup_sp(device_mesh, sp_size)
-    sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size)
+    sp_group = sequence_parallel._sp_group
 
     batch_size = 2
     unpad_seq_len = 128