
fix ci test #51

Closed
meichangsu1 wants to merge 15 commits into dev from sp_ljl_dev

Conversation

@meichangsu1
Collaborator

No description provided.

Add a new test file `test_sequence_parallel_single_attention.py` to verify the correctness of the sequence parallel attention implementation. The test includes a distributed setup using torch.distributed and compares outputs between sequence parallel and local attention modes. Also adds an empty `__init__.py` to the transformers test directory for proper module imports.
- Add `_enable_strict_determinism` helper to disable TF32 and enable deterministic algorithms
- Add `_to_local` helper to unwrap DTensors for gradient comparison
- Update test to use full world size for sequence parallel group and increase head count
- Switch to float32 dtype for stricter numerical alignment
- Improve gradient comparison by cloning and unwrapping tensors
…free logic

- Replace HfConfigFactory utility with direct get_config_attr function
- Move get_llm_model to shared transformers utilities
- Remove padding_free parameter and related conditional logic
- Simplify attention mask construction for padded tokens
- Update SequenceParallelConfig to drop padding_free field
- Add detection of packed batches via `_is_packed_position_ids` heuristic
- Raise RuntimeError when SDPA backend is used with packed batches, as SDPA lacks native packed/varlen support
- Build 2D attention_mask for padded sequences to ensure correct FlashAttention2 unpad behavior
- Avoid unnecessary 4D causal mask generation for packed/padding-free batches
Introduce a new cookbook script demonstrating supervised fine-tuning with a single controller using sequence parallelism (SP) and FSDP across 4 GPUs. The example includes:
- Device mesh configuration with dp=2 and fsdp=2 dimensions
- PackingDataset setup with self-cognition data and left truncation
- Training loop with LoRA adapter, AdamW optimizer, and periodic evaluation
- Checkpoint saving based on loss improvement
- Validation of FSDP + SP input slicing across multiple GPUs
…formers cookbook

- Add new single_controller_sp.py example demonstrating FSDP + SP validation over 4 GPUs
- Move legacy single_controller_sp.py to transformers/sp_fsdp_dense.py
- Add shell script sp_fsdp_dense.sh for running the example
- Update imports and structure to use twinkle framework components
…irectory

Relocate test_sequence_parallel_single_attention.py from tests/transformers/ to tests/sequence_parallel/ to better organize test files by feature area. This improves maintainability and aligns with the project's test structure conventions.
- Add bash script header and comments to `sp_fsdp_dense.sh` explaining how to enable sequence parallelism with ulysses_size
- Remove duplicate `import os` statement in transformers.py for cleaner code
- Fix minor formatting by removing extra blank line in transformers_utils.py
- Switch from `ray` to `local` mode for twinkle initialization
- Add evaluation function with separate dataset slice
- Increase dataset size from 100 to 500 samples
- Add cosine warmup learning rate scheduler
- Remove unused torch import and remote_group parameters
- Adjust batch size from 4 to 8 and logging frequency to every 20 steps
- Improve logging with train configs and total steps information
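The cosine warmup scheduler mentioned above is not shown in the diff; a common way to build one with `LambdaLR` (warmup length and decay shape are assumptions) is:

```python
import math
import torch

def cosine_warmup_lambda(warmup_steps: int, total_steps: int):
    # Linear warmup to the base LR, then cosine decay toward zero.
    def fn(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return fn

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=1e-4)
sched = torch.optim.lr_scheduler.LambdaLR(opt, cosine_warmup_lambda(10, 100))
```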
Removed unnecessary imports (`math`, `os`, `SimpleNamespace`) from the sequence_parallel strategy file to clean up the codebase and improve maintainability.
- Use `sequence_parallel._sp_group` directly instead of calling `_get_sp_group_from_device_mesh`
- Simplifies test setup by relying on internal module state after `_setup_sp`
@gemini-code-assist
Contributor

Summary of Changes

Hello @meichangsu1, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refines the sequence parallelism strategy within the twinkle framework, making it more robust and adaptable. The changes address how data is processed for packed batches, improve the accuracy of loss calculation in distributed settings, and provide clearer model introspection. A new, dedicated test suite ensures the stability and correctness of these distributed training features, while updated examples demonstrate their proper usage.

Highlights

  • Sequence Parallelism Enhancements: Refactored the core sequence parallelism (SP) logic to improve handling of packed batches, ensuring correct cu_seqlens derivation for FlashAttention2 and proper label masking for cross-subsequence supervision. The padding_free parameter was removed, simplifying the configuration.
  • Loss Reduction Improvements: Enhanced the reduce_loss mechanism within the SP strategy to support configurable 'sum' and 'mean' reductions, providing more flexibility in how losses are aggregated across parallel ranks. The loss_reduction is now dynamically set based on the loss instance.
  • Model Utility Refinements: Introduced a more robust get_llm_model utility function that intelligently unwraps various parallel and PEFT wrappers to accurately locate the underlying language model, improving compatibility and maintainability.
  • New Test Suite for SP: Added a comprehensive test file (test_sequence_parallel_single_attention.py) to validate the sequence parallel attention implementation, covering scenarios with and without padding, and its integration with Fully Sharded Data Parallel (FSDP).
  • Example Script Updates: Updated the sp_fsdp_dense.py example to use 'local' mode for twinkle.initialize, adjusted dataset creation, refined batch sizing, and integrated a learning rate scheduler. A new shell script (sp_fsdp_dense.sh) was added to facilitate running this example.


Changelog
  • cookbook/transformers/sp_fsdp_dense.py
    • Removed unused torch import.
    • Added DATASETS constant for dataset path.
    • Updated twinkle.initialize to use 'local' mode and removed groups parameter.
    • Refactored eval function for clarity and updated dataset slicing.
    • Simplified create_dataset template configuration.
    • Adjusted train function's DataLoader batch size and removed remote_group from DataLoader and TransformersModel.
    • Integrated learning rate scheduler setup into the training process.
    • Streamlined the training loop's logging and checkpoint saving logic.
  • cookbook/transformers/sp_fsdp_dense.sh
    • Added a new script to execute the sp_fsdp_dense.py example with torchrun.
  • src/twinkle/model/transformers/strategy/sequence_parallel.py
    • Cleaned up unused imports (math, SimpleNamespace).
    • Centralized get_llm_model utility by importing it from twinkle.utils.transformers_utils.
    • Refined gradient scaling in _SequenceParallelLoss backward pass by removing world_size multiplication.
    • Implemented explicit cu_seqlens derivation for packed batches in FlashAttention2.
    • Added validation for SDPA backend with packed batches, raising an error if detected.
    • Removed padding_free configuration from SequenceParallel.prepare and SequenceParallelConfig.
    • Removed pad_and_split_mm_tokens method.
    • Introduced _is_packed_position_ids static method for robust packed batch detection.
    • Added is_packed tracking to self.extra_kwargs in pad_and_split_inputs and prepare_inputs.
    • Improved attention_mask generation and label handling for packed batches during input preparation.
    • Enhanced reduce_loss to support configurable 'sum' and 'mean' reductions.
    • Refactored postprocess_outputs to ensure outputs is dict-like and to trim gathered logits to original length.
  • src/twinkle/model/transformers/transformers.py
    • Dynamically configured loss_reduction for sequence parallel strategy based on the loss instance.
  • src/twinkle/utils/transformers_utils.py
    • Improved get_modules_to_not_convert return logic.
    • Introduced a comprehensive get_llm_model utility for unwrapping parallel and PEFT models and locating the LLM module.
  • tests/sequence_parallel/test_sequence_parallel_single_attention.py
    • Added a new test file with unit tests for sequence parallel attention.
    • Included helper functions for distributed setup and deterministic execution.
    • Defined a _SingleAttention module for testing.
    • Implemented _run_worker_single_attn to test sequence parallelism with and without padding, comparing SP and DP outputs and gradients.
    • Implemented _run_worker_single_attn_fsdp to test FSDP with sequence parallelism.
    • Added unittest.TestCase for test_single_attention, test_single_attention_padding, and test_single_attention_fsdp.
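The `get_llm_model` unwrapping described in the changelog could be sketched roughly as follows; the wrapper attribute names and the loop are assumptions for illustration, not the actual implementation in `transformers_utils.py`:

```python
# Peel common wrapper attributes (DDP/FSDP expose `.module`, PEFT exposes
# `.base_model`, causal-LM heads expose `.model`) until reaching the
# innermost module. The `seen` set guards against cyclic references.
WRAPPER_ATTRS = ("module", "base_model", "model")

def get_llm_model(model):
    seen = set()
    while id(model) not in seen:
        seen.add(id(model))
        for attr in WRAPPER_ATTRS:
            inner = getattr(model, attr, None)
            if inner is not None and id(inner) not in seen:
                model = inner
                break
    return model
```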


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces significant enhancements and fixes related to sequence parallelism (SP) and FlashAttention2 integration. Key changes include refactoring the sequence_parallel.py module to remove unused imports, simplify config attribute retrieval, and improve handling of packed batches by introducing an is_packed flag. This flag is used to dynamically derive cu_seqlens for FlashAttention2 and to raise an error if SDPA is used with packed batches, as it's not supported. The padding_free concept was removed, and the reduce_loss method in the SP strategy was updated to correctly handle 'sum' and 'mean' reductions for global loss calculation, with a review comment highlighting and correcting an erroneous division by world_size. Additionally, the postprocess_outputs method was refined to ensure compatibility with dict-like ModelOutput containers and to trim gathered logits to their original sequence length. The transformers.py file was updated to dynamically set the loss_reduction in the SP strategy based on the loss instance. A new get_llm_model utility was added to transformers_utils.py for robust LLM module extraction, with review comments suggesting more specific exception handling (ImportError instead of broad Exception). The example script sp_fsdp_dense.py was updated to use 'local' mode for twinkle.initialize, streamline the training and evaluation loops, and add an LR scheduler. Finally, a new test file test_sequence_parallel_single_attention.py was added to validate SP functionality, including padding and FSDP integration, ensuring correctness of forward and backward passes.

```diff
 def create_dataset(data_slice=None):
     dataset = Dataset(
-        dataset_meta=DatasetMeta("ms://swift/self-cognition", data_slice=data_slice)
+        dataset_meta=DatasetMeta(DATASETS, data_slice=range(500))
```


high

The data_slice parameter of the create_dataset function is currently being ignored due to the hardcoded range(500). This can lead to unexpected behavior as calls from eval and train with different data_slice values will not have the intended effect. To fix this and restore the expected behavior, the function should use the data_slice parameter that is passed to it.

Suggested change:

```diff
-    dataset_meta=DatasetMeta(DATASETS, data_slice=range(500))
+    dataset_meta=DatasetMeta(DATASETS, data_slice=data_slice)
```

Comment on lines +1006 to +1008

```python
if sequence_parallel.world_size > 1:
    out_metric = out.detach() / sequence_parallel.world_size
    return out_metric + (out - out.detach())
```


high

The division by sequence_parallel.world_size when calculating out_metric seems incorrect. For 'sum' reduction, out represents the global sum of losses. Dividing this by world_size results in an average of sums, which is not a standard metric. To report the correct global loss for metric tracking, this division should likely be removed.

Suggested change:

```diff
 if sequence_parallel.world_size > 1:
-    out_metric = out.detach() / sequence_parallel.world_size
+    out_metric = out.detach()
     return out_metric + (out - out.detach())
```
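For context, `out_metric + (out - out.detach())` is a straight-through trick: the forward value equals the detached metric, while gradients still flow through the live loss term. A tiny standalone demonstration:

```python
import torch

# The reported value is the rescaled metric (1.5), but backward still
# differentiates through the original loss `out` (gradient of x.sum()).
x = torch.tensor([1.0, 2.0], requires_grad=True)
out = x.sum()                   # "global" loss, value 3.0
out_metric = out.detach() / 2   # rescaled value for logging only
reported = out_metric + (out - out.detach())
reported.backward()
```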

Comment on lines +1019 to +1021

```python
if sequence_parallel.world_size > 1:
    out_metric = out.detach() / sequence_parallel.world_size
    return out_metric + (out - out.detach())
```


high

The division by sequence_parallel.world_size when calculating out_metric seems incorrect. When using 'mean' reduction, out should already represent the global mean loss. Dividing it again by the world size would lead to an incorrect metric value. This division should probably be removed to ensure the reported loss is accurate.

Suggested change:

```diff
 if sequence_parallel.world_size > 1:
-    out_metric = out.detach() / sequence_parallel.world_size
+    out_metric = out.detach()
     return out_metric + (out - out.detach())
```

Comment on lines +152 to +153

```python
except Exception:
    pass
```


medium

Using a broad except Exception: can mask unexpected errors and complicate debugging. It is better to catch a more specific exception. In this case, ImportError would be more appropriate, as it's the expected error if the accelerate library is not installed.

Suggested change:

```diff
-except Exception:
-    pass
+except ImportError:
+    pass
```

Comment on lines +161 to +162

```python
except Exception:
    pass
```


medium

Similar to the previous block, using a broad except Exception: is not ideal as it can hide bugs. It would be safer and clearer to catch the specific ImportError that would occur if peft is not installed.

Suggested change:

```diff
-except Exception:
-    pass
+except ImportError:
+    pass
```

@meichangsu1 deleted the sp_ljl_dev branch February 13, 2026 09:05