Conversation
Add a new test file `test_sequence_parallel_single_attention.py` to verify the correctness of the sequence parallel attention implementation. The test includes a distributed setup using torch.distributed and compares outputs between sequence parallel and local attention modes. Also adds an empty `__init__.py` to the transformers test directory for proper module imports.
- Add `_enable_strict_determinism` helper to disable TF32 and enable deterministic algorithms
- Add `_to_local` helper to unwrap DTensors for gradient comparison
- Update test to use the full world size for the sequence parallel group and increase the head count
- Switch to float32 dtype for stricter numerical alignment
- Improve gradient comparison by cloning and unwrapping tensors
…free logic
- Replace HfConfigFactory utility with direct get_config_attr function
- Move get_llm_model to shared transformers utilities
- Remove padding_free parameter and related conditional logic
- Simplify attention mask construction for padded tokens
- Update SequenceParallelConfig to drop padding_free field
- Add detection of packed batches via `_is_packed_position_ids` heuristic
- Raise RuntimeError when the SDPA backend is used with packed batches, as SDPA lacks native packed/varlen support
- Build a 2D attention_mask for padded sequences to ensure correct FlashAttention2 unpad behavior
- Avoid unnecessary 4D causal mask generation for packed/padding-free batches
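Packed batches are commonly recognized by `position_ids` restarting from zero partway through a row, and those restart points also give the `cu_seqlens` boundaries FlashAttention's varlen path needs. A minimal, self-contained sketch of both ideas (the heuristic and names are assumptions, not the PR's exact code):

```python
def is_packed_position_ids(position_ids):
    """Heuristic: a row is 'packed' if its position ids ever reset (drop) mid-row."""
    for row in position_ids:
        for prev, cur in zip(row, row[1:]):
            if cur <= prev:  # e.g. ... 5, 0, 1 ... signals a new packed sequence
                return True
    return False

def cu_seqlens_from_position_ids(row):
    """Cumulative sequence lengths for one packed row (FlashAttention varlen style)."""
    starts = [i for i, p in enumerate(row) if p == 0]
    return starts + [len(row)]
```

For example, a row `[0, 1, 2, 0, 1, 0]` packs three sequences of lengths 3, 2, and 1, giving boundaries `[0, 3, 5, 6]`.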
Introduce a new cookbook script demonstrating supervised fine-tuning with a single controller using sequence parallelism (SP) and FSDP across 4 GPUs. The example includes:
- Device mesh configuration with dp=2 and fsdp=2 dimensions
- PackingDataset setup with self-cognition data and left truncation
- Training loop with LoRA adapter, AdamW optimizer, and periodic evaluation
- Checkpoint saving based on loss improvement
- Validation of FSDP + SP input slicing across multiple GPUs
…formers cookbook
- Add new single_controller_sp.py example demonstrating FSDP + SP validation over 4 GPUs
- Move legacy single_controller_sp.py to transformers/sp_fsdp_dense.py
- Add shell script sp_fsdp_dense.sh for running the example
- Update imports and structure to use twinkle framework components
…irectory
Relocate test_sequence_parallel_single_attention.py from tests/transformers/ to tests/sequence_parallel/ to better organize test files by feature area. This improves maintainability and aligns with the project's test structure conventions.
- Add a bash script header and comments to `sp_fsdp_dense.sh` explaining how to enable sequence parallelism with ulysses_size
- Remove a duplicate `import os` statement in transformers.py for cleaner code
- Remove an extra blank line in transformers_utils.py
- Switch from `ray` to `local` mode for twinkle initialization
- Add an evaluation function with a separate dataset slice
- Increase dataset size from 100 to 500 samples
- Add a cosine warmup learning rate scheduler
- Remove unused torch import and remote_group parameters
- Adjust batch size from 4 to 8 and logging frequency to every 20 steps
- Improve logging with train configs and total-steps information
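The commit adds a cosine warmup scheduler but the schedule itself is not shown here. The standard shape is a linear ramp to the base learning rate followed by cosine decay to zero; a self-contained sketch (function name and signature are illustrative):

```python
import math

def cosine_warmup_lr(step, total_steps, base_lr, warmup_steps):
    """Linear warmup to base_lr, then cosine decay to zero over the remaining steps."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

In practice this is usually wired in via `torch.optim.lr_scheduler.LambdaLR` with a multiplier of `cosine_warmup_lr(step, ...) / base_lr`.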
Removed unnecessary imports (`math`, `os`, `SimpleNamespace`) from the sequence_parallel strategy file to clean up the codebase and improve maintainability.
- Use `sequence_parallel._sp_group` directly instead of calling `_get_sp_group_from_device_mesh`
- Simplify test setup by relying on internal module state after `_setup_sp`
Summary of Changes

This pull request significantly refines the sequence parallelism strategy.

Highlights
Code Review
This pull request introduces significant enhancements and fixes related to sequence parallelism (SP) and FlashAttention2 integration. Key changes include refactoring the sequence_parallel.py module to remove unused imports, simplify config attribute retrieval, and improve handling of packed batches by introducing an is_packed flag. This flag is used to dynamically derive cu_seqlens for FlashAttention2 and to raise an error if SDPA is used with packed batches, as it is not supported.

The padding_free concept was removed, and the reduce_loss method in the SP strategy was updated to correctly handle 'sum' and 'mean' reductions for global loss calculation, with a review comment highlighting and correcting an erroneous division by world_size. Additionally, the postprocess_outputs method was refined to ensure compatibility with dict-like ModelOutput containers and to trim gathered logits to their original sequence length.

The transformers.py file was updated to dynamically set the loss_reduction in the SP strategy based on the loss instance. A new get_llm_model utility was added to transformers_utils.py for robust LLM module extraction, with review comments suggesting more specific exception handling (ImportError instead of a broad Exception). The example script sp_fsdp_dense.py was updated to use 'local' mode for twinkle.initialize, streamline the training and evaluation loops, and add an LR scheduler. Finally, a new test file test_sequence_parallel_single_attention.py was added to validate SP functionality, including padding and FSDP integration, ensuring correctness of forward and backward passes.
```diff
 def create_dataset(data_slice=None):
     dataset = Dataset(
-        dataset_meta=DatasetMeta("ms://swift/self-cognition", data_slice=data_slice)
+        dataset_meta=DatasetMeta(DATASETS, data_slice=range(500))
```
The data_slice parameter of the create_dataset function is currently being ignored due to the hardcoded range(500). This can lead to unexpected behavior as calls from eval and train with different data_slice values will not have the intended effect. To fix this and restore the expected behavior, the function should use the data_slice parameter that is passed to it.
```diff
-        dataset_meta=DatasetMeta(DATASETS, data_slice=range(500))
+        dataset_meta=DatasetMeta(DATASETS, data_slice=data_slice)
```
```python
if sequence_parallel.world_size > 1:
    out_metric = out.detach() / sequence_parallel.world_size
    return out_metric + (out - out.detach())
```
The division by sequence_parallel.world_size when calculating out_metric seems incorrect. For 'sum' reduction, out represents the global sum of losses. Dividing this by world_size results in an average of sums, which is not a standard metric. To report the correct global loss for metric tracking, this division should likely be removed.
```diff
 if sequence_parallel.world_size > 1:
-    out_metric = out.detach() / sequence_parallel.world_size
+    out_metric = out.detach()
     return out_metric + (out - out.detach())
```
```python
if sequence_parallel.world_size > 1:
    out_metric = out.detach() / sequence_parallel.world_size
    return out_metric + (out - out.detach())
```
The division by sequence_parallel.world_size when calculating out_metric seems incorrect. When using 'mean' reduction, out should already represent the global mean loss. Dividing it again by the world size would lead to an incorrect metric value. This division should probably be removed to ensure the reported loss is accurate.
```diff
 if sequence_parallel.world_size > 1:
-    out_metric = out.detach() / sequence_parallel.world_size
+    out_metric = out.detach()
     return out_metric + (out - out.detach())
```
```python
except Exception:
    pass
```
Using a broad except Exception: can mask unexpected errors and complicate debugging. It is better to catch a more specific exception. In this case, ImportError would be more appropriate, as it's the expected error if the accelerate library is not installed.
```diff
-except Exception:
+except ImportError:
     pass
```
```python
except Exception:
    pass
```
No description provided.