Rebase #514
Conversation
Co-authored-by: zhaochenyang20 <zhaochenyang20@gmail.com>
Co-authored-by: PopSoda2002 <zhouhp.me@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: zijiexia <zijie_xia@icloud.com>
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
Co-authored-by: Ratish1 <formula733@gmail.com>
…tron & FSDP alignment (radixark#412)
…4 training, bug fix, etc. (radixark#426)
Co-authored-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
Co-authored-by: Jiajun Li <guapisolo@gmail.com>
Co-authored-by: Yusheng Su <radixark@ac-h200-user-3.tail134ba0.ts.net>
Co-authored-by: Ethan (Yusheng) Su <yushengsu.thu@gmail.com>
Summary of Changes

Hello @xiuhu17, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces several key updates and new features to the repository. It focuses on improving build reproducibility, enhancing attention mechanisms, adding support for FSDP training, and providing more customization options. The changes include updates to build scripts, Dockerfiles, code implementations, and documentation, all aimed at improving the performance, stability, and flexibility of the system.
Code Review
This pull request appears to be a major rebase, incorporating a wide range of new features, refactoring, and dependency updates. The changes are extensive, touching many parts of the codebase.
Key improvements include the introduction of a unified training loop structure, better support for FSDP, and new features like INT4 training, on-policy distillation, and integrations with Terminal Bench and SWE-agent. The refactoring efforts, such as creating a ParallelState and centralizing data processing logic, are commendable and should improve maintainability.
However, there is a significant concern regarding the monkey-patching of torch.distributed._shard.sharded_tensor.ShardedTensor in miles/backends/megatron_utils/checkpoint.py to disable validation. While this might be a necessary performance optimization for now, it is a risky hack that could lead to subtle bugs and become a maintenance burden. It would be preferable to find a solution through official APIs or contribute an improvement upstream to PyTorch.
I've also identified a potential regression where speculative decoding metrics might no longer be logged. Please see the specific comment for details.
```python
log_dict = {}
log_dict |= dict_add_prefix(compute_statistics(response_lengths), "response_len/")
log_dict |= _compute_zero_std_metrics(args, samples)
log_dict |= _compute_spec_metrics(args, samples)
```
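For context on this pattern: each `compute_*` helper returns a flat dict of metrics, and `|=` (the in-place dict-union operator, Python 3.9+) merges them into `log_dict`. Below is a minimal sketch of what a helper like `dict_add_prefix` presumably does; this is hypothetical, as the actual implementation in the miles codebase is not shown in this diff.

```python
def dict_add_prefix(d: dict, prefix: str) -> dict:
    """Return a copy of `d` with `prefix` prepended to every key.

    Hypothetical sketch of the helper used above; the real miles
    implementation may differ.
    """
    return {f"{prefix}{key}": value for key, value in d.items()}


# Example: namespacing response-length statistics before merging them
# into the shared log_dict with the in-place dict-union operator.
log_dict = {}
log_dict |= dict_add_prefix({"mean": 128.0, "max": 512.0}, "response_len/")
assert log_dict == {"response_len/mean": 128.0, "response_len/max": 512.0}
```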
```dockerfile
# mv /usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/libnccl.so.2 /usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/libnccl.so.2.bak && \
# cp -r third_party/nccl/build/lib/* /usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/

RUN [ ! -f /root/.tmux.conf ] || rm /root/.tmux.conf
```
Code Review
This pull request introduces a significant and wide-ranging set of changes, primarily focused on refactoring the training backends for better modularity and adding extensive new features. Key improvements include the introduction of a ParallelState abstraction to unify parallelism logic, major refactoring of the FSDP and Megatron actors, and the addition of fault tolerance for rollout engines. The PR also adds support for new models like DeepSeek-V3.2, new agent frameworks such as strands-sglang and swe-agent, and enhances low-precision training capabilities. The documentation has been substantially updated to reflect these new features. My review identifies one area of concern regarding a monkey-patch for performance optimization, which could pose a future maintenance risk.
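The `ParallelState` abstraction is only described at a high level in this review. As a rough, hypothetical illustration of what unifying parallelism logic behind a single object can look like, consider the sketch below; all names, fields, and the rank layout are assumptions, not the PR's actual API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ParallelState:
    """Hypothetical sketch of a parallelism-state object; the PR's real
    ParallelState is not shown in this review and may differ."""

    dp_size: int = 1  # data-parallel group size
    tp_size: int = 1  # tensor-parallel group size
    pp_size: int = 1  # pipeline-parallel group size

    @property
    def world_size(self) -> int:
        # Total ranks when the three parallel dimensions are composed.
        return self.dp_size * self.tp_size * self.pp_size

    def dp_rank(self, global_rank: int) -> int:
        # Assumes ranks vary fastest over TP, then PP, then DP,
        # which is one common (but not universal) rank layout.
        return global_rank // (self.tp_size * self.pp_size)


# Usage: both the FSDP and Megatron backends can consult one shared object
# instead of each re-deriving group sizes from environment variables.
state = ParallelState(dp_size=4, tp_size=2, pp_size=2)
assert state.world_size == 16
assert state.dp_rank(global_rank=5) == 1
```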
```python
from miles.utils import megatron_bridge_utils

try:
    # Here we patch out `validate_non_overlapping_shards_metadata` in both
    # functions because it is really slow for large models with many shards.
    # TODO: find a less hacky way to do this.
    import torch.distributed as dist
    import torch.distributed._shard.sharding_spec as shard_spec
    from torch.distributed._shard.sharded_tensor import ShardedTensor
    from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata
    from torch.distributed._shard.sharded_tensor.shard import Shard
    from torch.distributed._shard.sharded_tensor.utils import _parse_and_validate_remote_device
    from torch.distributed._shard.sharding_spec.api import EnumerableShardingSpec

    # Replace the dataclass post-init hook so that constructing an
    # EnumerableShardingSpec skips the expensive shard-overlap validation.
    def __post_init__(self):
        pass

    EnumerableShardingSpec.__post_init__ = __post_init__

    @classmethod
    def _init_from_local_shards_and_global_metadata(  # type: ignore[override]
        cls,
        local_shards: list[Shard],
        sharded_tensor_metadata: ShardedTensorMetadata,
        process_group=None,
        init_rrefs=False,
        sharding_spec=None,
    ) -> ShardedTensor:
        """
        Initialize a ShardedTensor with local shards and a global
        ShardedTensorMetadata built on each rank.

        Warning: This API is experimental and subject to change. It does
        not do cross-rank validations, and fully relies on the user
        for the correctness of sharded_tensor_metadata on each rank.
        """
        process_group = cls._normalize_pg(process_group)
        current_rank = dist.get_rank()  # intentional to get global rank

        shards_metadata = sharded_tensor_metadata.shards_metadata

        # Kept from the upstream implementation; unused now that the
        # cross-rank validation step has been removed.
        local_shard_metadatas = []

        # collect local shard metadatas from the global sharded_tensor_metadata
        for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
            rank, local_device = _parse_and_validate_remote_device(process_group, shard_metadata.placement)

            if current_rank == rank:
                local_shard_metadatas.append(shard_metadata)

        tensor_properties = sharded_tensor_metadata.tensor_properties

        if sharding_spec is None:
            spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
        else:
            spec = sharding_spec

        sharded_tensor = ShardedTensor.__new__(
            ShardedTensor,
            spec,
            sharded_tensor_metadata.size,
            dtype=tensor_properties.dtype,
            layout=tensor_properties.layout,
            pin_memory=tensor_properties.pin_memory,
            requires_grad=tensor_properties.requires_grad,
        )

        # validation skipped; attach local_shards directly
        sharded_tensor._local_shards = local_shards
        sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

        # run post initialization, i.e. map registration, rpc initialization
        sharded_tensor._post_init()
        return sharded_tensor

    ShardedTensor._init_from_local_shards_and_global_metadata = _init_from_local_shards_and_global_metadata

except ImportError:
    pass
```
This large try...except block monkey-patches PyTorch's ShardedTensor and EnumerableShardingSpec to bypass a performance-intensive validation step. While the performance gain might be necessary for large models, this approach is brittle and poses a significant maintenance risk: it is likely to break with future PyTorch updates.

To mitigate this risk, consider the following:

- Add PyTorch version checks: gate this patch to specific PyTorch versions that are known to be compatible. This will prevent silent failures or unexpected behavior when the library is updated.
- Improve error handling: instead of a silent `except ImportError: pass`, log a warning when the patching fails. This would make it clear that the performance optimization is not being applied.
- Upstream the issue: if this is a general performance problem in PyTorch's distributed checkpointing, report it to the PyTorch team. They might provide a proper API to disable this validation or offer a more efficient implementation in the future.

The TODO comment indicates awareness of the issue, but strengthening the implementation with version checks and better error handling would make this less risky. A sketch of the first two suggestions follows.
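As a concrete illustration of the first two suggestions, the patch could be gated on a tested PyTorch version range and emit a warning when it is skipped. This is a sketch, not the PR's code; the version bounds below are assumptions and would need to be verified against the releases the patch was actually tested on.

```python
import logging

import torch
from packaging.version import Version

logger = logging.getLogger(__name__)


def _torch_version_supported() -> bool:
    # Gate the patch to a range it was verified against (bounds assumed here).
    version = Version(torch.__version__.split("+")[0])
    return Version("2.1") <= version < Version("2.6")


if _torch_version_supported():
    try:
        from torch.distributed._shard.sharding_spec.api import EnumerableShardingSpec

        # Skip the slow shard-overlap validation, as in the original patch.
        EnumerableShardingSpec.__post_init__ = lambda self: None
    except ImportError as exc:
        logger.warning(
            "ShardedTensor validation patch not applied (%s); "
            "checkpoint loading may be slow for large models.",
            exc,
        )
else:
    logger.warning(
        "torch %s is outside the tested range for the ShardedTensor "
        "validation patch; skipping it.",
        torch.__version__,
    )
```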