Conversation
Signed-off-by: root <zhangyuekai@foxmail.com>
Signed-off-by: Yuekai Zhang <zhangyuekai@foxmail.com>
📝 Walkthrough

This PR adds support for the Qwen2.5 Omni multimodal model to Megatron-Bridge. It introduces a complete model implementation, including transformer configurations, HuggingFace-to-Megatron bridge infrastructure, a model provider, multimodal RoPE position-encoding utilities, and comprehensive unit tests.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    actor User
    participant Bridge as Qwen25OmniBridge
    participant Provider as Qwen25OmniModelProvider
    participant Model as Qwen25OmniModel
    participant ThinkerModel as Qwen25OmniThinkerModel
    participant VisionEnc as Vision Encoder<br/>(HF)
    participant AudioEnc as Audio Encoder<br/>(HF)
    participant LanguageModel as Language Model<br/>(Megatron)
    User->>Bridge: Load HF model
    Bridge->>Provider: provider_bridge()
    activate Provider
    Provider->>Model: Instantiate with configs
    activate Model
    Model->>ThinkerModel: Initialize thinker
    activate ThinkerModel
    ThinkerModel->>VisionEnc: Initialize vision encoder
    ThinkerModel->>AudioEnc: Initialize audio encoder
    ThinkerModel->>LanguageModel: Initialize language model
    deactivate ThinkerModel
    deactivate Model
    deactivate Provider
    User->>Model: forward(input_ids, pixel_values,<br/>input_features, ...)
    activate Model
    Model->>ThinkerModel: forward()
    activate ThinkerModel
    ThinkerModel->>VisionEnc: Encode images/videos
    activate VisionEnc
    VisionEnc-->>ThinkerModel: vision_embeddings
    deactivate VisionEnc
    ThinkerModel->>AudioEnc: get_audio_features()
    activate AudioEnc
    AudioEnc-->>ThinkerModel: audio_embeddings
    deactivate AudioEnc
    ThinkerModel->>LanguageModel: Get text embeddings
    activate LanguageModel
    LanguageModel-->>ThinkerModel: text_embeddings
    deactivate LanguageModel
    ThinkerModel->>ThinkerModel: Substitute vision/audio<br/>embeddings at token positions
    ThinkerModel->>ThinkerModel: Compute 3D RoPE<br/>position_ids
    ThinkerModel->>LanguageModel: forward(embeddings,<br/>position_ids, masks, ...)
    activate LanguageModel
    LanguageModel-->>ThinkerModel: logits/output
    deactivate LanguageModel
    ThinkerModel-->>Model: output
    deactivate ThinkerModel
    Model-->>User: logits
    deactivate Model
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 3 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Actionable comments posted: 11
🧹 Nitpick comments (1)
src/megatron/bridge/models/qwen_omni/qwen25_omni_provider.py (1)
24-24: Replace `typing.List` with built-in generic syntax. Use `list[int]` for `mrope_section` and remove the `List` import.

Proposed fix:

```diff
-from typing import List
@@
-mrope_section: List[int] = field(default_factory=lambda: [16, 24, 24])
+mrope_section: list[int] = field(default_factory=lambda: [16, 24, 24])
```

As per coding guidelines, use built-in generics (`list`, `dict`, `tuple`) instead of typing equivalents.

Also applies to: 79-79
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/models/qwen_omni/qwen25_omni_provider.py` at line 24, Replace the typing.List usages with built-in generics: remove the import "from typing import List" and change any annotations that use List (notably the mrope_section annotation) to use the built-in form (e.g., list[int]); update any other occurrences in this module that reference List (such as the other annotation around qwen25_omni provider functions/variables) to their equivalent built-in generic types.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/megatron/bridge/models/qwen_omni/__init__.py`:
- Around line 20-24: The __all__ export list in this module is unsorted; update
the __all__ variable so its entries are in deterministic alphabetical order
(e.g., ensure "Qwen25OmniBridge", "Qwen25OmniModel", "Qwen25OmniModelProvider"
are sorted) to satisfy Ruff RUF022; locate the __all__ definition containing
"Qwen25OmniModel", "Qwen25OmniBridge", and "Qwen25OmniModelProvider" and reorder
the entries alphabetically (or generate the list with sorted(...)) so the lint
warning is resolved.
In `@src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/model.py`:
- Around line 43-44: Constructor parameters talker_transformer_config and
token2wav_transformer_config are declared but unused; either remove them from
the constructor signature or explicitly retain them by assigning to instance
attributes (e.g., self.talker_transformer_config = talker_transformer_config and
self.token2wav_transformer_config = token2wav_transformer_config) or, if
intentionally unused, prefix them with an underscore (e.g.,
_talker_transformer_config) or add an explicit comment/annotation to silence
ARG002; update the __init__ of the model class in modeling_qwen25_omni to
reflect the chosen approach so intent is clear.
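One way to make the intent explicit, sketched here with a stripped-down constructor (the real `__init__` takes many more arguments; the attribute names simply mirror the parameters):

```python
# Sketch: keep the currently-unused configs by storing them on the
# instance, which documents the intent and silences ARG002.
class Qwen25OmniModel:
    def __init__(self, talker_transformer_config=None, token2wav_transformer_config=None):
        # Retained for future talker/token2wav support.
        self.talker_transformer_config = talker_transformer_config
        self.token2wav_transformer_config = token2wav_transformer_config

model = Qwen25OmniModel(talker_transformer_config={"num_layers": 4})
```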
- Line 50: The nullable type hints in the class constructor (__init__) and the
forward method should use the explicit union syntax `T | None` instead of bare
defaults of `None`; update the annotations for parameters such as pg_collection
(in __init__) and in forward: position_ids, attention_mask, labels, loss_mask,
inference_params, packed_seq_params, extra_block_kwargs, pixel_values,
pixel_values_videos, image_grid_thw, video_grid_thw, image_input_mask,
video_input_mask, cp_img_num, and images_padded to use the form e.g.
torch.Tensor | None, dict | None, list[int] | None, list[bool] | None, etc.,
keeping their default values as None and preserving existing parameter names and
semantics.
In `@src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/rope.py`:
- Line 114: The loop currently shadows the function argument input_ids by using
"for i, input_ids in enumerate(total_input_ids):" which makes later references
(e.g., device handling that expects the outer input_ids) ambiguous; rename the
loop variable to something like batch_input_ids (e.g., "for i, batch_input_ids
in enumerate(total_input_ids):") and update all usages inside that loop to use
batch_input_ids, while ensuring any code that should reference the original
function argument input_ids (such as the device resolution/handling code later
in the function) continues to reference the outer input_ids variable.
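The rename can be sketched like this (the function structure is assumed from the review comment, with tensor handling reduced to plain lists):

```python
# Sketch: rename the loop variable so the function argument
# `input_ids` is never shadowed inside the per-batch loop.
def get_rope_index(input_ids):
    total_input_ids = input_ids
    lengths = []
    for i, batch_input_ids in enumerate(total_input_ids):  # was: input_ids
        lengths.append(len(batch_input_ids))
    # `input_ids` here still refers to the original argument, so any
    # later device/shape handling based on it stays unambiguous.
    return lengths, len(input_ids)

lengths, num_batches = get_rope_index([[1, 2, 3], [4, 5]])
```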
In `@src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/thinker_model.py`:
- Around line 176-181: The get_audio_features method currently assumes
feature_attention_mask is non-null and ignores audio_feature_lengths; update it
to first use audio_feature_lengths if provided (use that as lengths and
construct/derive a matching feature_attention_mask if needed), otherwise if
feature_attention_mask is None compute feature lengths from input_features
(e.g., full-length) or safely skip mask.sum(-1); guard every place that calls
feature_attention_mask.sum(-1) (including the later block around lines 185-190)
with a conditional so you only sum when feature_attention_mask is not None, and
ensure the returned attention mask and length values reflect the caller-provided
audio_feature_lengths when present.
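The precedence the comment asks for can be sketched as a small helper (tensor ops replaced with plain lists; real code would use torch, and the helper name is illustrative):

```python
# Sketch: prefer caller-provided lengths, fall back to the mask,
# then to the full feature length, so None never gets .sum() called on it.
def resolve_audio_lengths(input_features, feature_attention_mask=None, audio_feature_lengths=None):
    if audio_feature_lengths is not None:
        return list(audio_feature_lengths)
    if feature_attention_mask is not None:
        # list-based stand-in for feature_attention_mask.sum(-1)
        return [sum(row) for row in feature_attention_mask]
    # No mask at all: assume every feature sequence is full length.
    return [len(feats) for feats in input_features]

from_mask = resolve_audio_lengths([[0.1, 0.2, 0.3]], feature_attention_mask=[[1, 1, 0]])
from_caller = resolve_audio_lengths([[0.1, 0.2, 0.3]], audio_feature_lengths=[5])
from_default = resolve_audio_lengths([[0.1, 0.2, 0.3]])
```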
- Around line 62-63: The constructor of ThinkerModel accepts pg_collection:
ProcessGroupCollection = None but immediately dereferences it; update the
__init__ (and the other places where pg_collection is used) to guard against
None by either early-returning/skipping process-group setup when pg_collection
is None or by creating a default ProcessGroupCollection instance; specifically,
wrap usages of pg_collection (e.g., any calls like pg_collection.get_or_create,
pg_collection.add, pg_collection.create_process_group) in an if pg_collection is
not None: ... block or assign a fallback local variable before dereference so no
attribute access occurs when pg_collection is None.
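The guard can be sketched as follows; a plain dict stands in for `ProcessGroupCollection`, whose actual API is not assumed here:

```python
# Sketch: never dereference pg_collection when it is None -- either
# skip process-group setup or fall back to a default.
class ThinkerModel:
    def __init__(self, pg_collection=None):
        if pg_collection is None:
            # Skip process-group setup entirely rather than crashing.
            self.tp_group = None
        else:
            self.tp_group = pg_collection.get("tensor_parallel")

no_pg = ThinkerModel()  # default of None no longer raises
with_pg = ThinkerModel({"tensor_parallel": "tp0"})
```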
- Around line 224-225: Replace the two assert checks for unsupported modes with
explicit raises of NotImplementedError so they cannot be bypassed with Python
-O; specifically change the checks that reference inference_params and
packed_seq_params in thinker_model.py (the assertions asserting inference_params
is None and packed_seq_params is None) to raise NotImplementedError with the
same descriptive messages ("not support inference" and "not support
packed_seq_params") to fail fast and clearly indicate unsupported functionality.
In
`@src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/transformer_config.py`:
- Line 16: Remove the typing.List import and replace all uses of the typing
generic with the built-in generic syntax: remove the line "from typing import
List" and change every annotation like "List[int]" to "list[int]" (also update
any other List[...] occurrences in this module, e.g., the annotation referenced
near the other occurrence). Ensure type hints across transformer_config.py use
built-in generics (list, dict, tuple) and run the type-checker to confirm no
residual imports of typing.List remain.
In `@src/megatron/bridge/models/qwen_omni/qwen25_omni_bridge.py`:
- Line 87: The current expression mrope_section=getattr(text_config,
"rope_scaling", {}).get("mrope_section", [16, 24, 24]) can raise AttributeError
if text_config.rope_scaling exists but is None; change the guard so you first
retrieve rope_scaling (e.g., rope_scaling = getattr(text_config, "rope_scaling",
None)) and then call .get on a safe dict (e.g., (rope_scaling or
{}).get("mrope_section", [16,24,24])) or use an explicit conditional to set
mrope_section; update the assignment in qwen25_omni_bridge.py to use this safe
lookup for rope_scaling.
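The safe lookup can be sketched like this; `SimpleNamespace` stands in for the HF text config object:

```python
from types import SimpleNamespace

# Sketch: handle both a missing rope_scaling attribute and an
# explicit rope_scaling=None without raising AttributeError.
def get_mrope_section(text_config):
    rope_scaling = getattr(text_config, "rope_scaling", None)
    return (rope_scaling or {}).get("mrope_section", [16, 24, 24])

default = get_mrope_section(SimpleNamespace(rope_scaling=None))
custom = get_mrope_section(SimpleNamespace(rope_scaling={"mrope_section": [8, 8, 8]}))
```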
In `@tests/unit_tests/models/qwen_omni/modeling_qwen25_omni/test_omni_model.py`:
- Around line 204-205: get_data_batch currently uses undefined locals
random_video and random_audio causing NameError; update the helper
(get_data_batch) to accept random_video and random_audio as parameters or create
them inside the function (similar to random_image) so all referenced variables
are defined; ensure any tests calling get_data_batch (and related calls around
lines referencing this helper) are updated to pass the new args if you choose
parameters, and keep the function signature consistent across usages.
- Around line 40-49: The fixtures processor and hf_config currently call
AutoProcessor.from_pretrained and AutoConfig.from_pretrained which perform
network downloads; change these tests to avoid runtime downloads by either (a)
replacing the fixtures with lightweight local test doubles (e.g., a simple stub
object implementing the minimal interface used in tests) or (b) monkeypatching
AutoProcessor.from_pretrained and AutoConfig.from_pretrained to return mocked
instances; update the processor and hf_config fixtures to return those local
stubs/mocks (or use pytest monkeypatch in the module-level fixtures) so tests
run offline and no external model artifact is fetched.
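Option (a) can be sketched with lightweight stand-ins; the attribute and return shapes below are illustrative, not the real `AutoProcessor`/`AutoConfig` interfaces:

```python
# Sketch: minimal offline test doubles for the processor and config
# fixtures, so no from_pretrained() download happens at test time.
class StubProcessor:
    def __call__(self, text=None, **kwargs):
        # Fixed, predictable token ids for assertions.
        return {"input_ids": [[1, 2, 3]]}

class StubConfig:
    hidden_size = 64
    num_attention_heads = 4

processor = StubProcessor()
hf_config = StubConfig()
batch = processor(text="hello")
```

With option (b), `pytest`'s `monkeypatch` fixture would patch `AutoProcessor.from_pretrained` and `AutoConfig.from_pretrained` to return these stubs instead.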
---
Nitpick comments:
In `@src/megatron/bridge/models/qwen_omni/qwen25_omni_provider.py`:
- Line 24: Replace the typing.List usages with built-in generics: remove the
import "from typing import List" and change any annotations that use List
(notably the mrope_section annotation) to use the built-in form (e.g.,
list[int]); update any other occurrences in this module that reference List
(such as the other annotation around qwen25_omni provider functions/variables)
to their equivalent built-in generic types.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: c82075a4-ee37-4efc-9062-4b0ad957220d
📒 Files selected for processing (11)
- src/megatron/bridge/models/__init__.py
- src/megatron/bridge/models/conversion/auto_bridge.py
- src/megatron/bridge/models/qwen_omni/__init__.py
- src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/__init__.py
- src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/model.py
- src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/rope.py
- src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/thinker_model.py
- src/megatron/bridge/models/qwen_omni/modeling_qwen25_omni/transformer_config.py
- src/megatron/bridge/models/qwen_omni/qwen25_omni_bridge.py
- src/megatron/bridge/models/qwen_omni/qwen25_omni_provider.py
- tests/unit_tests/models/qwen_omni/modeling_qwen25_omni/test_omni_model.py
Signed-off-by: root <zhangyuekai@foxmail.com>
@yaoyu-33 Would you mind helping reviewing the PR? Thanks. |
What does this PR do?
Add a one line overview of what this PR aims to accomplish.
Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items you can still open a "Draft" PR.
Additional Information
Summary by CodeRabbit
Release Notes