
Fix sequence padding for DiT. Support CP + THD for DiT #39

Merged

abhinavg4 merged 20 commits into main from fix_dit_cp on Nov 17, 2025
Conversation

sajadn (Contributor) commented Nov 13, 2025

  • Add padding between samples instead of only at the end (see the sketch below)
  • Enable context parallelism for DiT with the THD (packed) sequence format
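As a rough illustration of the packing change, here is a minimal sketch assuming a THD-style packed batch; the helper and its `pad_to` parameter are hypothetical, not the PR's actual code:

```python
import torch

def pack_thd_with_per_sample_padding(samples, pad_to=64):
    """Hypothetical helper: pack variable-length samples into one THD batch,
    padding each sample to a multiple of `pad_to` rather than padding the
    packed batch only once at the end."""
    chunks, cu_seqlens, cu_seqlens_padded = [], [0], [0]
    for s in samples:  # each s: [seq_len, hidden]
        padded_len = ((s.shape[0] + pad_to - 1) // pad_to) * pad_to
        pad = torch.zeros(padded_len - s.shape[0], s.shape[1], dtype=s.dtype)
        chunks.append(torch.cat([s, pad], dim=0))
        cu_seqlens.append(cu_seqlens[-1] + s.shape[0])                # real-token offsets
        cu_seqlens_padded.append(cu_seqlens_padded[-1] + padded_len)  # padded-storage offsets
    packed = torch.cat(chunks, dim=0)  # [total_padded_tokens, hidden]
    return (packed,
            torch.tensor(cu_seqlens, dtype=torch.int32),
            torch.tensor(cu_seqlens_padded, dtype=torch.int32))
```

Padding each sample up to a fixed multiple keeps every sample's padding contiguous with that sample, which is what lets a context-parallel split of the packed sequence land on sample-aligned boundaries.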

Fix sequence padding for DiT. Add support for DiT Context Parallel with THD.

Signed-off-by: sajadn <snorouzi@nvidia.com>
copy-pr-bot bot commented Nov 13, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

abhinavg4 and others added 19 commits November 13, 2025 22:58
- Updated `get_query_key_value_tensors` method in `dit_attention.py` to include an `output_gate` parameter and set `split_qkv` to default to `True`.
- Modified `WanLayerWithAdaLN` class in `wan_layer_spec.py` to add `rotary_pos_cos_sin` parameter for improved positional encoding handling.
- Added initialization of `pg_collection` in both `DiTCrossAttentionModel` and `WanModel` to ensure proper handling of process groups.
- This change checks that `pg_collection` exists and is not None before assigning it, making the models more robust (see the sketch below).
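A minimal sketch of the guard described above, assuming the attribute names from the commit message; the surrounding class is hypothetical, not the actual `DiTCrossAttentionModel` code:

```python
class DiTModelSketch:
    """Hypothetical stand-in for DiTCrossAttentionModel / WanModel."""

    def __init__(self, config, pg_collection=None):
        self.config = config
        # Guard: assign only when a ProcessGroupCollection was actually
        # provided, so an existing default is not clobbered with None.
        if pg_collection is not None:
            self.pg_collection = pg_collection
```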
Update CONTRIBUTING.md to include detailed setup instructions for the development environment and Docker container usage. Added sections for building and running the container, as well as setting the PYTHONPATH for DFM.
Refactor import statements in dit_model.py to streamline dependencies. Removed the redundant import of ProcessGroupCollection, improving code clarity and maintainability.
- Updated string quotes in `dit_model.py` and `wan_model.py` for consistency, changing from single to double quotes.
- Reformatted the `get_query_key_value_tensors` method call in `dit_attention.py` for improved readability by breaking it into multiple lines.
Fix sequence padding for DiT. Add support for DiT Context Parallel with THD.

Signed-off-by: sajadn <snorouzi@nvidia.com>
Signed-off-by: sajadn <snorouzi@nvidia.com>
Add mock DiT dataset. Make DiT attention compatible with megatron bridge.

Signed-off-by: Sajad Norouzi <snorouzi@nvidia.com>
Signed-off-by: sajadn <snorouzi@nvidia.com>
Implement functional smoke tests for Mcore DiT pretrain and update the test command in GPU mock tests. Added a new test file for DiT pretraining and modified the existing GPU test script to run all tests in the recipes directory.
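A functional smoke test of this kind typically just runs the recipe for a couple of steps and asserts a clean exit. A minimal sketch, assuming a pytest-style test; the module path and CLI flags are hypothetical, not the actual test file:

```python
import subprocess
import sys

def test_dit_pretrain_smoke():
    """Hypothetical smoke test: run a few DiT pretraining steps on mock data
    and require a zero exit code."""
    result = subprocess.run(
        [sys.executable, "-m", "recipes.dit.pretrain",   # hypothetical module path
         "--mock-data", "--train-iters", "2"],           # hypothetical flags
        capture_output=True, text=True, timeout=1800,
    )
    assert result.returncode == 0, result.stderr
```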
abhinavg4 (Contributor) commented:

/ok to test 33bfbbe

abhinavg4 (Contributor) left a comment:

Left a small comment

    )
else:
    self.cross_attention = None
self.cross_attention = build_module(
I would say leave the above stuff commented out in case any customer wants to explore it. In any case, please remove the comments above if you are removing the code.

seq_length=2048,
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
task_encoder_seq_length=8000,
Ideally all this hardcoded stuff would go into a config, but it's OK since I know we do the same in Wan and many other places in the code (see the config sketch below).
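To make the suggestion concrete, here is a minimal sketch of pulling those hardcoded values into a config object; the dataclass and its field names are hypothetical, not an existing DFM config:

```python
from dataclasses import dataclass

@dataclass
class DiTRecipeConfig:
    """Hypothetical config collecting the values hardcoded in the recipe."""
    seq_length: int = 2048
    task_encoder_seq_length: int = 8000
    micro_batch_size: int = 1
    global_batch_size: int = 32

# The recipe would then read from one place instead of scattering literals:
cfg = DiTRecipeConfig()
# setup_data(seq_length=cfg.seq_length, ...)  # hypothetical call site
```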

@abhinavg4 abhinavg4 merged commit 5bb89f5 into main Nov 17, 2025
16 checks passed
linnanwang pushed a commit that referenced this pull request Nov 17, 2025
* Fix sequence padding for DiT. Add support for DiT Context Parallel with THD.

Signed-off-by: sajadn <snorouzi@nvidia.com>

* Enhance DiT and Wan layer specifications

- Updated `get_query_key_value_tensors` method in `dit_attention.py` to include an `output_gate` parameter and set `split_qkv` to default to `True`.
- Modified `WanLayerWithAdaLN` class in `wan_layer_spec.py` to add `rotary_pos_cos_sin` parameter for improved positional encoding handling.

* Implement ProcessGroupCollection initialization in DiT and Wan models

- Added initialization of `pg_collection` in both `DiTCrossAttentionModel` and `WanModel` to ensure proper handling of process groups.
- This change checks if `pg_collection` exists and is not None before assigning it, enhancing the robustness of the models.

* Update CONTRIBUTING.md to include detailed setup instructions for development environment and Docker container usage. Added sections for building and running the container, as well as setting the PYTHONPATH for DFM.

* Refactor import statements in dit_model.py to streamline dependencies. Removed redundant import of ProcessGroupCollection, enhancing code clarity and maintainability.

* Refactor code style in DiT and Wan models

- Updated string quotes in `dit_model.py` and `wan_model.py` for consistency, changing from single to double quotes.
- Reformatted the `get_query_key_value_tensors` method call in `dit_attention.py` for improved readability by breaking it into multiple lines.

* Revert M4 changes

* Ruff

* Ruff

* Lint

* Fix sequence padding for DiT. Add support for DiT Context Parallel with THD.

Signed-off-by: sajadn <snorouzi@nvidia.com>

* fix cp inference. add cu_seqlen_kv_padded which was missing.

Signed-off-by: sajadn <snorouzi@nvidia.com>
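For context on that fix: with THD packing, the attention kernel needs offsets over the padded storage in addition to offsets over real tokens. A sketch of the two tensors with illustrative lengths; the actual call site in `dit_attention.py` may differ, and the commented call shape is an assumption:

```python
import torch

# Two packed samples of 100 and 60 real tokens, each padded to a multiple
# of 64 (illustrative values only).
cu_seqlens_kv = torch.tensor([0, 100, 160], dtype=torch.int32)         # real tokens
cu_seqlens_kv_padded = torch.tensor([0, 128, 256], dtype=torch.int32)  # padded storage

# Without cu_seqlens_kv_padded, a THD attention kernel cannot skip the
# padding between samples and would read keys/values at wrong offsets.
# out = attention(q, k, v, qkv_format="thd",
#                 cu_seqlens_q=cu_seqlens_q,
#                 cu_seqlens_kv=cu_seqlens_kv,
#                 cu_seqlens_q_padded=cu_seqlens_q_padded,
#                 cu_seqlens_kv_padded=cu_seqlens_kv_padded)  # hypothetical call shape
```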

* Add mock DiT dataset. Make DiT attention compatible with megatron bridge.

Signed-off-by: Sajad Norouzi <snorouzi@nvidia.com>

* fix checkpoint loading issue.

Signed-off-by: sajadn <snorouzi@nvidia.com>

* Implement functional smoke tests for Mcore DiT pretrain and update test command in GPU mock tests. Added a new test file for DiT pretraining and modified the existing GPU test script to run all tests in the recipes directory.

---------

Signed-off-by: sajadn <snorouzi@nvidia.com>
Signed-off-by: Sajad Norouzi <snorouzi@nvidia.com>
Co-authored-by: Abhinav Garg <abhinavg@stanford.edu>
huvunvidia pushed a commit that referenced this pull request Feb 12, 2026
wplf commented Mar 5, 2026

Hi, thank you for the great work.

May I ask whether you tested loss convergence for CP=1 and CP=2? Could I see the loss curves?
