Refactor z-image-turbo inference code #832
Conversation
Summary of Changes

Hello @gushiqiao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request focuses on refactoring the z-image-turbo inference code to improve memory efficiency and performance. The changes include adding support for FP8 quantization, enabling CPU offloading, and optimizing tensor processing for single-sample inference. The removal of custom Triton kernels and the integration of new attention and FFN implementations aim to simplify the codebase and enhance maintainability.
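The FP8 quantization and offloading code itself is not shown in the excerpts below. As a rough, hedged illustration of the general idea of per-tensor FP8 weight quantization in PyTorch (the function names and the single-scale scheme are assumptions, not this repository's actual implementation):

```python
import torch

def quantize_fp8_per_tensor(weight: torch.Tensor):
    """Quantize a weight tensor to float8_e4m3fn with a single per-tensor scale."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max          # 448.0 for e4m3fn
    scale = weight.abs().max().clamp(min=1e-12) / fp8_max   # avoid divide-by-zero
    qweight = (weight / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return qweight, scale

def dequantize_fp8(qweight: torch.Tensor, scale: torch.Tensor, dtype=torch.bfloat16) -> torch.Tensor:
    """Cast back to a high-precision dtype and re-apply the scale."""
    return qweight.to(dtype) * scale

w = torch.randn(128, 128, dtype=torch.bfloat16)
qw, s = quantize_fp8_per_tensor(w)
w_approx = dequantize_fp8(qw, s)  # approximate reconstruction of w
```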
Code Review
This pull request introduces a significant refactoring of the Z-Image inference pipeline, primarily to support lazy loading and CPU offloading at a finer granularity (block level). The changes streamline tensor operations by removing explicit batch dimensions in intermediate processing steps within pre_infer.py, post_infer.py, and transformer_infer.py. Transformer blocks are now more modular, organized into 'compute phases' (AdaLN modulation, attention, FFN), which is reflected in the transformer_weights.py file. The removal of triton_ops.py indicates a shift away from custom Triton kernels. Additionally, support for sequence parallelism has been explicitly removed, and the model converter tool has been updated to recognize the new Z-Image model type. The runner logic is also adjusted to handle lazy loading of sub-modules. Overall, these changes aim to improve memory efficiency and flexibility for offloading, while simplifying the internal data flow for single-sample inference.
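As a hedged sketch of the block-level offloading pattern described above (the class name, weight names, and placeholder "phases" below are invented for illustration; the real transformer blocks use dedicated AdaLN/attention/FFN weight objects rather than a flat dict):

```python
import torch

class OffloadedBlock:
    """Toy block whose weights live on CPU and are copied to the compute device only while it runs."""

    def __init__(self, weights: dict):
        # Weights stay on CPU until the block is actually executed.
        self.weights = {k: v.to("cpu") for k, v in weights.items()}

    def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor:
        device = hidden_states.device
        # Copy this block's weights to the compute device just-in-time.
        w = {k: v.to(device, non_blocking=True) for k, v in self.weights.items()}
        # Placeholder "compute phases": AdaLN-style modulation, attention-ish projection, FFN-ish projection.
        hidden_states = hidden_states * (1 + w["adaln_scale"]) + w["adaln_shift"]
        hidden_states = hidden_states @ w["attn_proj"]
        hidden_states = hidden_states @ w["ffn_proj"]
        del w  # drop the device copies; the CPU originals remain for the next step
        return hidden_states

dim = 64
block = OffloadedBlock({
    "adaln_scale": torch.zeros(dim),
    "adaln_shift": torch.zeros(dim),
    "attn_proj": torch.eye(dim),
    "ffn_proj": torch.eye(dim),
})
out = block(torch.randn(16, dim))
```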
)
else:
hidden_states_out = block_weight.attention.calculate.apply(
# todo
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
else:
self.seq_p_group = None
raise NotImplementedError
The NotImplementedError for seq_parallel indicates that this feature is currently unsupported. If seq_parallel is a critical feature for this model, this should be addressed. If it's intentionally deprecated or not planned, a more descriptive comment or a configuration flag to disable it might be appropriate.
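For example, the bare raise could be replaced with a guarded, descriptive error at the same spot; the seq_parallel config key and the message below are only suggestions, not code from this PR:

```python
if self.config.get("seq_parallel", False):
    raise NotImplementedError(
        "seq_parallel is not supported by the refactored Z-Image inference path; "
        "disable seq_parallel in the config or use the previous implementation."
    )
self.seq_p_group = None
```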
pre_infer_out.hidden_states = F.pad(pre_infer_out.hidden_states, (0, 0, 0, padding_size))
pre_infer_out.hidden_states = torch.chunk(pre_infer_out.hidden_states, world_size, dim=1)[cur_rank]
return pre_infer_out
raise NotImplementedError
The NotImplementedError for _seq_parallel_pre_process indicates that this feature is currently unsupported. If seq_parallel is a critical feature for this model, this should be addressed. If it's intentionally deprecated or not planned, a more descriptive comment or a configuration flag to disable it might be appropriate.
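For reference, the removed pre-processing visible in the hunk above pads the sequence so it divides evenly across ranks and keeps only the current rank's shard. A standalone sketch of that logic (the function signature and the padding_size computation are assumptions filled in around the lines shown in the diff):

```python
import torch
import torch.nn.functional as F

def seq_parallel_pre_process(hidden_states: torch.Tensor, world_size: int, cur_rank: int) -> torch.Tensor:
    """Pad the sequence dim so it splits evenly across ranks, then keep this rank's shard."""
    seq_len = hidden_states.shape[1]  # assumed layout: [batch, seq, hidden]
    padding_size = (world_size - seq_len % world_size) % world_size
    if padding_size > 0:
        hidden_states = F.pad(hidden_states, (0, 0, 0, padding_size))  # pad along the seq dim
    return torch.chunk(hidden_states, world_size, dim=1)[cur_rank]
```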
dist.all_gather(gathered_noise_pred, noise_pred, group=self.seq_p_group)
noise_pred = torch.cat(gathered_noise_pred, dim=1)
return noise_pred
raise NotImplementedError
The NotImplementedError for _seq_parallel_post_process indicates that this feature is currently unsupported. If seq_parallel is a critical feature for this model, this should be addressed. If it's intentionally deprecated or not planned, a more descriptive comment or a configuration flag to disable it might be appropriate.
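Similarly, the removed post-processing gathers every rank's shard of noise_pred and concatenates along the sequence dimension. A standalone sketch (the signature is an assumption, and trimming any padding added during pre-processing is omitted):

```python
import torch
import torch.distributed as dist

def seq_parallel_post_process(noise_pred: torch.Tensor, seq_p_group) -> torch.Tensor:
    """Gather each rank's sequence shard of noise_pred and stitch them back together."""
    world_size = dist.get_world_size(group=seq_p_group)
    gathered_noise_pred = [torch.empty_like(noise_pred) for _ in range(world_size)]
    dist.all_gather(gathered_noise_pred, noise_pred, group=seq_p_group)
    # Any padding added during pre-processing would be sliced off here (omitted).
    return torch.cat(gathered_noise_pred, dim=1)
```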
self.attn_type = config.get("attn_type", "flash_attn3")
self.heads = config["n_heads"]
self.rms_norm_type = config.get("rms_norm_type", "sgl-kernel")
self.rms_norm_type = config.get("rms_norm_type", "one-pass")
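This hunk changes the default rms_norm_type from "sgl-kernel" to "one-pass". Assuming "one-pass" refers to a plain PyTorch RMSNorm computed directly over the last dimension rather than the fused sgl-kernel op (the actual backend selection is not shown in this diff), the underlying math is the standard RMSNorm; a minimal sketch:

```python
import torch

def rms_norm_one_pass(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Standard RMSNorm: scale by the reciprocal root-mean-square over the last dim."""
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    normed = x.float() * torch.rsqrt(variance + eps)
    return (normed * weight.float()).to(x.dtype)
```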
dim = config["dim"]
hidden_dim = int(dim / 3 * 8)  # FeedForward hidden_dim = dim / 3 * 8
self.rms_norm_type = config.get("rms_norm_type", "one-pass")
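The sizing rule hidden_dim = int(dim / 3 * 8) is the roughly 8/3 expansion commonly paired with SwiGLU feed-forward layers. Assuming (without confirmation from this diff) that the Z-Image FFN follows that structure, a hypothetical sketch with invented layer names:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    """Hypothetical SwiGLU-style FFN using the hidden_dim = dim / 3 * 8 rule from the diff."""

    def __init__(self, dim: int):
        super().__init__()
        hidden_dim = int(dim / 3 * 8)  # e.g. dim=3072 -> hidden_dim=8192
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```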
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
No description provided.