Implement Blackwell MXFP8 recipe #512
base: main
Conversation
Summary of Changes

Hello @zianglih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a new Python utility script that converts existing Hugging Face models to the MXFP8 format.
Code Review
This pull request introduces a new script for converting Hugging Face models to MXFP8 format. The script is well-structured and covers the main steps of conversion: quantizing weights, handling non-quantized layers, and updating model configuration files. My review includes suggestions to improve robustness by using context managers for file operations, to improve efficiency by using appropriate data structures, and to improve overall code clarity and maintainability. I've also pointed out a potential issue with dynamic path modification for imports.
```python
config_path = os.path.join(input_path, "config.json")
if os.path.exists(config_path):
    cfg = json.load(open(config_path))
```
Opening files without a with statement can lead to resource leaks if an exception occurs before the file is closed. It's best practice to use a with block to ensure files are always closed correctly. This applies here and on lines 161 and 167.
For example, this line could be rewritten as:
```python
with open(config_path, 'r') as f:
    cfg = json.load(f)
```
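The same pattern applies to the write sites the comment references; a hedged sketch follows (the variable name output_config_path and the use of json.dump are assumptions, since lines 161 and 167 are not shown in this diff):

```python
# Hypothetical write-side counterpart of the same fix; the actual
# call sites at lines 161 and 167 may open different files.
with open(output_config_path, 'w') as f:
    json.dump(cfg, f, indent=2)
```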
```python
except ImportError:
    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
    sglang_path = os.path.join(repo_root, "sglang", "python")
    if sglang_path not in sys.path:
        sys.path.append(sglang_path)
    from sglang.srt.layers.quantization.fp8_utils import mxfp8_group_quantize
```
Dynamically modifying sys.path can make dependency management fragile and less explicit. It's generally better to rely on standard Python packaging practices. Consider instructing users to install the project in editable mode (e.g., pip install -e .) or to set their PYTHONPATH environment variable. If you must keep this dynamic import, consider adding a comment explaining why it's necessary and how it works.
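If the dynamic fallback is kept, a minimal sketch of the commented variant this comment asks for (assuming the repository layout implied by the diff, with a sglang checkout next to this repository):

```python
import os
import sys

try:
    from sglang.srt.layers.quantization.fp8_utils import mxfp8_group_quantize
except ImportError:
    # Fallback for source checkouts where sglang is not pip-installed:
    # add <repo_root>/sglang/python to sys.path so the import resolves
    # against the sibling checkout. Prefer `pip install -e .` where possible.
    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
    sglang_path = os.path.join(repo_root, "sglang", "python")
    if sglang_path not in sys.path:
        sys.path.append(sglang_path)
    from sglang.srt.layers.quantization.fp8_utils import mxfp8_group_quantize
```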
```python
if k % 32 != 0:
    raise ValueError(f"Last dim {k} must be divisible by 32 for MXFP8.")
```
This check for divisibility by 32 is already performed in the should_quantize function before this function is called. This makes the if condition and ValueError redundant in the current control flow. To maintain this as a precondition for the function, consider changing it to an assert. This documents the assumption without adding runtime overhead in production builds (if assertions are disabled) and makes the code's intent clearer.
```python
assert k % 32 == 0, f"Last dim {k} must be divisible by 32 for MXFP8."
```
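For context on why the last dimension must be divisible by 32: MXFP8 quantizes tensors in groups of 32 consecutive elements along the last dimension, with one shared power-of-two (E8M0) scale per group. A minimal illustrative sketch; this is not the sglang mxfp8_group_quantize implementation, and the function name and scale-selection details are assumptions:

```python
import torch

def mxfp8_group_quantize_sketch(w: torch.Tensor, group_size: int = 32):
    # Each group of 32 values along the last dim shares one
    # power-of-two scale; hence the divisibility requirement.
    assert w.shape[-1] % group_size == 0, "last dim must be divisible by 32"
    g = w.reshape(*w.shape[:-1], -1, group_size)
    amax = g.abs().amax(dim=-1, keepdim=True).clamp(min=2.0 ** -127)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for E4M3
    # Smallest power-of-two scale mapping the group max into FP8 range.
    scale = torch.exp2(torch.ceil(torch.log2(amax / fp8_max)))
    q = (g / scale).to(torch.float8_e4m3fn)
    return q.reshape(w.shape), scale.squeeze(-1)
```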
tools/convert_hf_to_mxfp8.py (Outdated)
```python
self.modules_to_not_convert: List[str] = []

def add_result(self, filename: str, q_weights: Dict[str, torch.Tensor], module_names: List[str]) -> None:
    for key, tensor in q_weights.items():
        self.weight_map[key] = filename
        self.total_size += tensor.numel() * tensor.element_size()
    self.modules_to_not_convert.extend(module_names)
```
The modules_to_not_convert attribute is a list that can grow large and accumulate duplicates, which are only removed later via list(set(...)). It's more efficient to use a set from the start to store unique module names; this avoids unnecessary memory usage and makes the intent clearer.

Note: you'll need to import Set from typing.

After this change, you'll also need to update line 155 to quantization_config['modules_to_not_convert'] = sorted(result_collector.modules_to_not_convert) to ensure deterministic output (sorted already returns a list, so an extra list() call is unnecessary).
```python
self.modules_to_not_convert: Set[str] = set()

def add_result(self, filename: str, q_weights: Dict[str, torch.Tensor], module_names: List[str]) -> None:
    for key, tensor in q_weights.items():
        self.weight_map[key] = filename
        self.total_size += tensor.numel() * tensor.element_size()
    self.modules_to_not_convert.update(module_names)
```
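Putting the suggestion together, a self-contained sketch of the collector (the class name ResultCollector is an assumption; the diff only shows a result_collector variable):

```python
from typing import Dict, List, Set

import torch

class ResultCollector:
    def __init__(self) -> None:
        self.weight_map: Dict[str, str] = {}
        self.total_size: int = 0
        self.modules_to_not_convert: Set[str] = set()

    def add_result(self, filename: str, q_weights: Dict[str, torch.Tensor],
                   module_names: List[str]) -> None:
        for key, tensor in q_weights.items():
            self.weight_map[key] = filename
            self.total_size += tensor.numel() * tensor.element_size()
        # A set deduplicates as entries arrive, so no later
        # list(set(...)) pass is needed.
        self.modules_to_not_convert.update(module_names)

# When writing the config, sort for deterministic output:
# quantization_config['modules_to_not_convert'] = sorted(collector.modules_to_not_convert)
```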
@HumansAnd
WIP. This PR depends on sgl-project/sglang#17449