Skip to content

Commit 2dc5ff8

Browse files
authored
sequence parallel fix bug (#47)
* feat(tests): add sequence parallel single attention test Add a new test file `test_sequence_parallel_single_attention.py` to verify the correctness of the sequence parallel attention implementation. The test includes a distributed setup using torch.distributed and compares outputs between sequence parallel and local attention modes. Also adds an empty `__init__.py` to the transformers test directory for proper module imports. * wip * feat(tests): enhance sequence parallel attention test determinism - Add `_enable_strict_determinism` helper to disable TF32 and enable deterministic algorithms - Add `_to_local` helper to unwrap DTensors for gradient comparison - Update test to use full world size for sequence parallel group and increase head count - Switch to float32 dtype for stricter numerical alignment - Improve gradient comparison by cloning and unwrapping tensors * remove __init__ * feat(sequence_parallel): refactor config handling and remove padding-free logic - Replace HfConfigFactory utility with direct get_config_attr function - Move get_llm_model to shared transformers utilities - Remove padding_free parameter and related conditional logic - Simplify attention mask construction for padded tokens - Update SequenceParallelConfig to drop padding_free field * feat(sequence_parallel): enforce flash_attention_2 for packed batches - Add detection of packed batches via `_is_packed_position_ids` heuristic - Raise RuntimeError when SDPA backend is used with packed batches, as SDPA lacks native packed/varlen support - Build 2D attention_mask for padded sequences to ensure correct FlashAttention2 unpad behavior - Avoid unnecessary 4D causal mask generation for packed/padding-free batches * feat(sft): add single controller SP packing example for Qwen2.5-7B Introduce a new cookbook script demonstrating supervised fine-tuning with a single controller using sequence parallelism (SP) and FSDP across 4 GPUs. 
The example includes: - Device mesh configuration with dp=2 and fsdp=2 dimensions - PackingDataset setup with self-cognition data and left truncation - Training loop with LoRA adapter, AdamW optimizer, and periodic evaluation - Checkpoint saving based on loss improvement - Validation of FSDP + SP input slicing across multiple GPUs * fix loss computation bug * feat(cookbook): add single controller SP example and reorganize transformers cookbook - Add new single_controller_sp.py example demonstrating FSDP + SP validation over 4 GPUs - Move legacy single_controller_sp.py to transformers/sp_fsdp_dense.py - Add shell script sp_fsdp_dense.sh for running the example - Update imports and structure to use twinkle framework components * refactor(tests): move sequence parallel attention test to dedicated directory Relocate test_sequence_parallel_single_attention.py from tests/transformers/ to tests/sequence_parallel/ to better organize test files by feature area. This improves maintainability and aligns with the project's test structure conventions. 
* feat: add sequence parallelism instructions and clean up imports - Add bash script header and comments to `sp_fsdp_dense.sh` explaining how to enable sequence parallelism with ulysses_size - Remove duplicate `import os` statement in transformers.py for cleaner code - Fix minor formatting by removing extra blank line in transformers_utils.py * refactor * feat: update training script with local mode and evaluation - Switch from `ray` to `local` mode for twinkle initialization - Add evaluation function with separate dataset slice - Increase dataset size from 100 to 500 samples - Add cosine warmup learning rate scheduler - Remove unused torch import and remote_group parameters - Adjust batch size from 4 to 8 and logging frequency to every 20 steps - Improve logging with train configs and total steps information * feat(transformers): remove unused imports in sequence_parallel module Removed unnecessary imports (`math`, `os`, `SimpleNamespace`) from the sequence_parallel strategy file to clean up the codebase and improve maintainability.
1 parent ec6016f commit 2dc5ff8

File tree

6 files changed

+611
-111
lines changed

6 files changed

+611
-111
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from functools import partial
22
import numpy as np
3-
import torch
43
from peft import LoraConfig
54

65
import twinkle
@@ -12,6 +11,7 @@
1211

1312
logger = get_logger()
1413
MODEL_ID = 'ms://Qwen/Qwen2.5-7B-Instruct'
14+
DATASETS='ms://swift/self-cognition'
1515

1616
device_group = [
1717
DeviceGroup(
@@ -30,79 +30,70 @@
3030
)
3131

3232
twinkle.initialize(
33-
mode="ray",
33+
mode="local",
3434
nproc_per_node=4,
35-
groups=device_group,
3635
global_device_mesh=device_mesh,
3736
lazy_collect=False,
3837
)
3938

39+
def eval(model):
40+
dataloader = DataLoader(
41+
dataset=partial(create_dataset, data_slice=range(100)),
42+
batch_size=4,
43+
device_mesh=device_mesh,
44+
)
45+
for _, batch in enumerate(dataloader):
46+
model.forward_only(inputs=batch, adapter_name="default")
47+
model.calculate_loss(adapter_name="default")
48+
return model.calculate_metric(is_training=False, adapter_name="default")
49+
4050

4151
def create_dataset(data_slice=None):
4252
dataset = Dataset(
43-
dataset_meta=DatasetMeta("ms://swift/self-cognition", data_slice=data_slice)
53+
dataset_meta=DatasetMeta(DATASETS, data_slice=range(500))
4454
)
4555
dataset.set_template(
4656
"Template",
47-
model_id=MODEL_ID,
48-
truncation_strategy="left",
49-
max_length=64,
57+
model_id=MODEL_ID
5058
)
5159
dataset.map(SelfCognitionProcessor("twinkle模型", "twinkle团队"))
5260
dataset.encode(batched=True)
5361
return dataset
54-
55-
56-
def eval(model: TransformersModel):
57-
dataloader = DataLoader(
58-
dataset=partial(create_dataset, data_slice=range(20)),
59-
batch_size=4,
60-
drop_last=True,
61-
device_mesh=device_mesh,
62-
remote_group="default",
63-
)
64-
for step, batch in enumerate(dataloader):
65-
model.forward_only(inputs=batch, adapter_name="default")
66-
model.calculate_loss(adapter_name="default")
67-
metrics = model.calculate_metric(is_training=False, adapter_name="default")
68-
return metrics()
69-
70-
7162
def train():
7263
dataloader = DataLoader(
7364
dataset=partial(create_dataset, data_slice=None),
74-
batch_size=4,
65+
batch_size=8,
7566
device_mesh=device_mesh,
76-
remote_group="default",
7767
)
7868

7969
model = TransformersModel(
8070
model_id=MODEL_ID,
8171
device_mesh=device_mesh,
8272
strategy="native_fsdp",
83-
remote_group="default",
8473
)
8574

8675
lora_config = LoraConfig(target_modules="all-linear")
8776
model.add_adapter_to_model("default", lora_config, gradient_accumulation_steps=1)
8877
model.set_optimizer("AdamW", lr=1e-4, adapter_name="default")
78+
model.set_lr_scheduler(
79+
scheduler_cls="CosineWarmupScheduler",
80+
num_warmup_steps=5,
81+
num_training_steps=len(dataloader),
82+
adapter_name="default",
83+
)
84+
85+
logger.info(model.get_train_configs(adapter_name="default"))
86+
logger.info(f"Total steps: {len(dataloader)}")
87+
8988

90-
loss_metric = 99.0
9189
for step, batch in enumerate(dataloader):
92-
if isinstance(batch, list) and len(batch) == 0:
93-
continue
94-
output = model.forward_backward(inputs=batch, adapter_name="default")
95-
loss_value = output() if callable(output) else output
96-
logger.info(f"step {step}, loss: {loss_value}")
90+
model.forward_backward(inputs=batch, adapter_name="default")
9791
model.clip_grad_and_step(adapter_name="default")
98-
if step % 50 == 0 and step > 0:
99-
metrics = eval(model)
100-
logger.info(f"Current is step {step} of {len(dataloader)}, metric: {metrics}")
101-
metrics["step"] = step
102-
if loss_metric > metrics["loss"]:
103-
model.save(f"checkpoint-{step}")
104-
loss_metric = metrics["loss"]
92+
if step % 20 == 0:
93+
metric = model.calculate_metric(is_training=True, adapter_name="default")
94+
logger.info(f"Current is step {step} of {len(dataloader)}, metric: {metric}")
95+
model.save("last-checkpoint", interval=1)
10596

10697

10798
if __name__ == "__main__":
108-
train()
99+
train()
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/bin/bash
2+
# To enable sequence parallelism, please set ulysses_size > 1
3+
# device_mesh = DeviceMesh(
4+
# device_type="cuda",
5+
# mesh=np.arange(4).reshape(2, 2),
6+
# mesh_dim_names=("dp", "fsdp"),
7+
# ulysses_size=2,
8+
# )
9+
#
10+
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 sp_fsdp_dense.py

0 commit comments

Comments
 (0)