Commit 2dc5ff8
fix sequence parallel bug (#47)
* feat(tests): add sequence parallel single attention test
Add a new test file `test_sequence_parallel_single_attention.py` to verify the correctness of the sequence parallel attention implementation. The test includes a distributed setup using torch.distributed and compares outputs between sequence parallel and local attention modes. Also adds an empty `__init__.py` to the transformers test directory for proper module imports.
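For reference, a minimal sketch of the comparison pattern the test described above follows; the `attn(..., sequence_parallel=...)` flag, the sequence-dimension sharding, and the helper name are illustrative assumptions, not the repository's actual API:

```python
import torch
import torch.distributed as dist

def _compare_sp_vs_local(attn, x, sp_group):
    # Hypothetical sketch: each rank takes one contiguous slice of the
    # sequence dimension, runs SP attention, and the gathered result is
    # checked against attention over the full sequence.
    rank = dist.get_rank(sp_group)
    world = dist.get_world_size(sp_group)
    local_len = x.shape[1] // world
    shard = x[:, rank * local_len:(rank + 1) * local_len].clone()
    out_sp = attn(shard, sequence_parallel=True)   # assumed flag
    gathered = [torch.empty_like(out_sp) for _ in range(world)]
    dist.all_gather(gathered, out_sp, group=sp_group)
    out_full = attn(x, sequence_parallel=False)
    torch.testing.assert_close(torch.cat(gathered, dim=1), out_full)
```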
* wip
* feat(tests): enhance sequence parallel attention test determinism
- Add `_enable_strict_determinism` helper to disable TF32 and enable deterministic algorithms
- Add `_to_local` helper to unwrap DTensors for gradient comparison (both helpers are sketched after this list)
- Update test to use full world size for sequence parallel group and increase head count
- Switch to float32 dtype for stricter numerical alignment
- Improve gradient comparison by cloning and unwrapping tensors
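A minimal sketch of the two helpers, assuming a recent PyTorch where `DTensor` lives under `torch.distributed.tensor`:

```python
import torch
from torch.distributed.tensor import DTensor

def _enable_strict_determinism():
    # Disable TF32 so float32 matmuls are bit-reproducible across modes.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    # Error out if a non-deterministic kernel would be selected.
    torch.use_deterministic_algorithms(True)

def _to_local(t):
    # DTensor gradients must be unwrapped before elementwise comparison.
    return t.to_local() if isinstance(t, DTensor) else t
```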
* remove __init__
* feat(sequence_parallel): refactor config handling and remove padding-free logic
- Replace the HfConfigFactory utility with a direct `get_config_attr` function (see the sketch after this list)
- Move get_llm_model to shared transformers utilities
- Remove padding_free parameter and related conditional logic
- Simplify attention mask construction for padded tokens
- Update SequenceParallelConfig to drop padding_free field
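A plausible shape for `get_config_attr`; the nested `text_config` fallback is an assumption (common for multimodal configs), not necessarily what this commit implements:

```python
def get_config_attr(config, name, default=None):
    # Hypothetical sketch: read the attribute from the top-level config,
    # falling back to a nested text_config if present.
    if hasattr(config, name):
        return getattr(config, name)
    text_config = getattr(config, "text_config", None)
    if text_config is not None and hasattr(text_config, name):
        return getattr(text_config, name)
    return default
```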
* feat(sequence_parallel): enforce flash_attention_2 for packed batches
- Add detection of packed batches via an `_is_packed_position_ids` heuristic (sketched after this list)
- Raise RuntimeError when SDPA backend is used with packed batches, as SDPA lacks native packed/varlen support
- Build 2D attention_mask for padded sequences to ensure correct FlashAttention2 unpad behavior
- Avoid unnecessary 4D causal mask generation for packed/padding-free batches
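One common heuristic for detecting packing is position_ids that reset (decrease) mid-row; a sketch under that assumption — the commit's actual check may differ:

```python
import torch

def _is_packed_position_ids(position_ids):
    # Hypothetical heuristic: when several sequences are packed into one
    # row, position_ids drop back toward 0 partway through the row.
    if position_ids is None or position_ids.dim() != 2:
        return False
    diffs = position_ids[:, 1:] - position_ids[:, :-1]
    return bool((diffs < 0).any())
```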
* feat(sft): add single controller SP packing example for Qwen2.5-7B
Introduce a new cookbook script demonstrating supervised fine-tuning with a single controller using sequence parallelism (SP) and FSDP across 4 GPUs. The example includes:
- Device mesh configuration with dp=2 and fsdp=2 dimensions (see the sketch after this list)
- PackingDataset setup with self-cognition data and left truncation
- Training loop with LoRA adapter, AdamW optimizer, and periodic evaluation
- Checkpoint saving based on loss improvement
- Validation of FSDP + SP input slicing across multiple GPUs
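The mesh wiring such a 4-GPU example needs is roughly as follows; the dim names and how twinkle consumes the mesh are assumptions, only the torch call itself is standard:

```python
from torch.distributed.device_mesh import init_device_mesh

# 4 GPUs as a 2x2 mesh: the outer "dp" dim replicates, the inner "fsdp"
# dim shards parameters (dim names here are illustrative).
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "fsdp"))
dp_mesh = mesh["dp"]
fsdp_mesh = mesh["fsdp"]
```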
* fix loss computation bug
* feat(cookbook): add single controller SP example and reorganize transformers cookbook
- Add new single_controller_sp.py example demonstrating FSDP + SP validation over 4 GPUs
- Move legacy single_controller_sp.py to transformers/sp_fsdp_dense.py
- Add shell script sp_fsdp_dense.sh for running the example
- Update imports and structure to use twinkle framework components
* refactor(tests): move sequence parallel attention test to dedicated directory
Relocate test_sequence_parallel_single_attention.py from tests/transformers/ to tests/sequence_parallel/ to better organize test files by feature area. This improves maintainability and aligns with the project's test structure conventions.
* feat: add sequence parallelism instructions and clean up imports
- Add bash script header and comments to `sp_fsdp_dense.sh` explaining how to enable sequence parallelism with ulysses_size
- Remove duplicate `import os` statement in transformers.py for cleaner code
- Fix minor formatting by removing extra blank line in transformers_utils.py
* refactor
* feat: update training script with local mode and evaluation
- Switch from `ray` to `local` mode for twinkle initialization
- Add evaluation function with separate dataset slice
- Increase dataset size from 100 to 500 samples
- Add a cosine warmup learning rate scheduler (sketched after this list)
- Remove unused torch import and remote_group parameters
- Adjust batch size from 4 to 8 and logging frequency to every 20 steps
- Improve logging with train configs and total steps information
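A standard cosine-with-warmup schedule expressed as a `LambdaLR`; step counts are placeholders and the script's actual implementation may differ:

```python
import math
from torch.optim.lr_scheduler import LambdaLR

def cosine_warmup(optimizer, warmup_steps, total_steps):
    # Linear ramp over warmup_steps, then cosine decay to zero.
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return LambdaLR(optimizer, lr_lambda)
```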
* feat(transformers): remove unused imports in sequence_parallel module
Removed unnecessary imports (`math`, `os`, `SimpleNamespace`) from the sequence_parallel strategy file to clean up the codebase and improve maintainability.

1 parent ec6016f · commit 2dc5ff8
File tree (6 files changed: +611 −111)
- cookbook/transformers
- src/twinkle
  - model/transformers
  - strategy
  - utils
- tests/sequence_parallel