Plan: Energon Dataloader Support for HF-Encoder VL Models
Context
PR #2440 adds Energon dataloader support for Qwen3-VL, which uses a Megatron-native vision encoder. The QwenVLTaskEncoder is tightly coupled to Qwen's architecture (image token expansion based on grid THW, Qwen-specific process_vision, Qwen2_5_VLVisualInputs container).
HF-encoder VL models use HuggingFace vision towers instead of Megatron-native encoders. These models handle image processing through their HF processor (AutoProcessor) and don't need the Qwen-specific token expansion logic. The target models are:
| Model | Vision Encoder | HF Processor | Notes |
|---|---|---|---|
| Gemma3-VL | SiglipVisionModel (HF) | Gemma3Processor | Fixed 256 tokens/image, bidirectional attention mask |
| Ministral3 | PixtralVisionModel (HF) | PixtralProcessor | Variable patch sizes |
| GLM-4.5V | Glm4vVisionModel (HF) | Glm4vProcessor | Needs grid THW + MRoPE (hybrid) |
Goal: Create a generic Energon task encoder that works for all HF-encoder VL models by delegating vision processing to the model's HF processor, then integrate it into each model's recipe.
Step 1: Create Generic Visual Inputs Container
File: src/megatron/bridge/training/utils/visual_inputs.py (modify existing)
Add a GenericVisualInputs dataclass alongside the existing Qwen2_5_VLVisualInputs:
```python
from dataclasses import dataclass, fields
from typing import Optional

import torch


@dataclass
class GenericVisualInputs:
    """Container for visual modality tensors from HF-encoder VL models.

    Holds the superset of visual fields needed across different HF-encoder
    models. Only non-None fields are passed to the model via
    normalized_for_model().
    """

    pixel_values: Optional[torch.Tensor] = None
    pixel_values_videos: Optional[torch.Tensor] = None
    image_grid_thw: Optional[torch.Tensor] = None
    video_grid_thw: Optional[torch.Tensor] = None
    image_sizes: Optional[torch.Tensor] = None

    def as_model_kwargs(self) -> dict[str, torch.Tensor]:
        """Return a mapping of non-None fields suitable for model forward kwargs."""
        return {f.name: getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None}

    def normalized_for_model(self) -> dict[str, torch.Tensor]:
        """Return non-None fields. No shape transformation is needed for HF-encoder models."""
        return self.as_model_kwargs()
```

This works with the existing `vlm_step.py` code, which iterates `val.__dict__` to move tensors to CUDA and calls `visual_inputs.normalized_for_model()` to get model kwargs.
Step 2: Create Generic HF-Encoder Task Encoder
File: src/megatron/bridge/data/energon/hf_encoder_task_encoder.py (new file)
Design
The key insight: for HF-encoder models, the HF processor already handles both tokenization and image processing via `apply_chat_template()` and `processor()`. We don't need the custom token expansion logic that Qwen3-VL requires. We just need to:
- Convert raw Energon samples into conversation format
- Run the HF processor to get `input_ids`, `pixel_values`, etc.
- Create loss masks (reuse the existing search-based pattern from `collate.py`)
- Package into batches with `GenericVisualInputs`
Class Structure
```python
class HFEncoderVLMTaskEncoder(DefaultTaskEncoder[ChatMLSample, HFEncoderTaskSample, HFEncoderTaskBatch, dict]):
    ...
```

Reuses the existing `ChatMLSample` and `cook_chatml_sample()` from `recipes/qwen_vl/data/energon/task_encoder.py` (move shared code to a common location, or import).
Key Methods
`__init__(processor, tokenizer, seq_length, visual_keys)`:

- `processor`: HF `AutoProcessor` instance (handles images + text)
- `tokenizer`: HF tokenizer (for loss mask creation)
- `seq_length`: max sequence length
- `visual_keys`: list of keys to extract from the processor output as visual inputs (e.g., `["pixel_values"]` for Gemma3/Ministral3, `["pixel_values", "image_grid_thw"]` for GLM-4.5V)
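The role of `visual_keys` can be sketched with a small helper (a sketch only; `extract_visual_inputs` is a hypothetical name, and the processor output is mocked as a plain dict):

```python
def extract_visual_inputs(processor_output: dict, visual_keys: list) -> dict:
    """Keep only the configured visual tensors from an HF processor's output.

    Keys the processor did not produce (e.g. for a text-only sample) are
    simply absent from the result rather than raising.
    """
    return {k: processor_output[k] for k in visual_keys if processor_output.get(k) is not None}


# Gemma3/Ministral3 only need pixel_values; GLM-4.5V also needs grid info.
mock_output = {"input_ids": [1, 2, 3], "pixel_values": "tensor", "image_grid_thw": None}
print(extract_visual_inputs(mock_output, ["pixel_values", "image_grid_thw"]))
# → {'pixel_values': 'tensor'}
```

Filtering out `None` values here is what lets one encoder class serve all three models: the per-model difference is reduced to the `visual_keys` list passed at construction time.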
`encode_sample(sample: ChatMLSample) -> HFEncoderTaskSample`:

- Parse conversation from `sample.conversation` (JSON string)
- Convert to HF conversation format (same logic as the `QwenVLTaskEncoder` conversation parsing)
- Process images: convert `sample.imgs` (list of PIL images) to the format the processor expects
- Run `processor.apply_chat_template(conversation, tokenize=True, ...)` to get `input_ids`
- Run `processor(images=images, ...)` to get `pixel_values` and other vision tensors
- Create the loss mask using the search-based pattern from `collate.py`: `create_multiturn_loss_mask_by_search()`
- Return `HFEncoderTaskSample` with all fields

`batch(samples: List[HFEncoderTaskSample]) -> HFEncoderTaskBatch`:

- Pad `input_ids` to the max length in the batch
- Stack/concatenate visual tensors (`pixel_values`, etc.)
- Create `attention_mask`, `position_ids`, `loss_mask`, `labels`
- Return `HFEncoderTaskBatch`

`encode_batch(batch: HFEncoderTaskBatch) -> dict`:

- Convert to dict
- Wrap visual tensors in `GenericVisualInputs`
- Return dict ready for `vlm_step.py`
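The padding and labeling step of `batch()` can be sketched in isolation (a simplified, pure-Python stand-in on lists rather than tensors; the `pad_id` default and the omission of next-token shifting are simplifications for illustration):

```python
IGNORE_INDEX = -100  # conventional ignore value for cross-entropy loss


def pad_and_label(samples, pad_id=0):
    """Pad input_ids to the batch max length and derive labels/loss_mask.

    Each sample is a dict with 'input_ids' and a same-length 'loss_mask'
    (1 where the token should contribute to the loss, 0 elsewhere).
    """
    max_len = max(len(s["input_ids"]) for s in samples)
    out = {"input_ids": [], "labels": [], "loss_mask": []}
    for s in samples:
        pad = max_len - len(s["input_ids"])
        ids = s["input_ids"] + [pad_id] * pad
        mask = s["loss_mask"] + [0] * pad  # padding never contributes to loss
        # Masked positions (prompts, padding) get IGNORE_INDEX in the labels.
        labels = [tok if keep else IGNORE_INDEX for tok, keep in zip(ids, mask)]
        out["input_ids"].append(ids)
        out["loss_mask"].append(mask)
        out["labels"].append(labels)
    return out
```

The real implementation would return stacked tensors and also build `attention_mask`/`position_ids` (via `get_ltor_masks_and_position_ids()`), but the pad-then-mask structure is the same.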
Reused utilities (from existing code)
- `cook_chatml_sample()` from `recipes/qwen_vl/data/energon/task_encoder.py`: sample parsing
- `create_multiturn_loss_mask_by_search()` from `data/vlm_datasets/collate.py`: loss masking
- `extract_skipped_token_ids()` from `data/vlm_datasets/token_utils.py`: pad token identification
- `get_ltor_masks_and_position_ids()` from `recipes/qwen_vl/data/energon/task_encoder.py`: attention mask/position IDs
- `Cooker`/`basic_sample_keys` from `megatron.energon`: sample cooking
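The search-based loss-masking pattern being reused can be illustrated roughly as follows (a minimal sketch of the idea, not the actual `create_multiturn_loss_mask_by_search()` implementation):

```python
def find_pattern_indices(ids, pattern, start=0):
    """Return (start, end) of the first occurrence of pattern in ids, or None."""
    n = len(pattern)
    for i in range(start, len(ids) - n + 1):
        if ids[i:i + n] == pattern:
            return i, i + n
    return None


def loss_mask_by_search(input_ids, assistant_token_ids):
    """Unmask only the assistant responses, located by searching for their
    token ids inside the fully tokenized conversation."""
    mask = [0] * len(input_ids)
    pos = 0  # search forward so repeated responses map to successive turns
    for response in assistant_token_ids:
        span = find_pattern_indices(input_ids, response, start=pos)
        if span is None:
            continue  # response not found (e.g. truncated away)
        s, e = span
        mask[s:e] = [1] * (e - s)
        pos = e
    return mask
```

Searching for each assistant turn's token ids avoids having to track exactly where the chat template inserts role headers and special tokens, which varies per processor.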
Data format
Same pickle-based webdataset format as Qwen3-VL (`jpgs` pickle field + `json` conversation field). The `cook_chatml_sample()` function already handles this.
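For illustration, one record in this format might be constructed as below (a sketch; the exact webdataset field keys should be checked against what `cook_chatml_sample()` expects):

```python
import json
import pickle


def make_sample(images, conversation):
    """Build one webdataset record: pickled image bytes + JSON conversation."""
    return {
        "jpgs": pickle.dumps(images),                # pickle field with raw JPEG bytes
        "json": json.dumps(conversation).encode(),   # ChatML-style conversation
    }


sample = make_sample(
    images=[b"\xff\xd8fake-jpeg-bytes"],  # placeholder, not a real JPEG
    conversation=[
        {"role": "user", "content": "<image>\nDescribe the image."},
        {"role": "assistant", "content": "A cat on a sofa."},
    ],
)

# Reading it back mirrors what the cooker does:
imgs = pickle.loads(sample["jpgs"])
conv = json.loads(sample["json"])
```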
Step 3: Refactor Shared Code
Move shared code from the Qwen3-VL task encoder to a common location so both encoders can use it.
File: src/megatron/bridge/data/energon/task_encoder_utils.py (new file)
Move these from recipes/qwen_vl/data/energon/task_encoder.py:
- `ChatMLSample` dataclass
- `cook_chatml_sample()` function
- `get_ltor_masks_and_position_ids()` function
- `find_pattern_indices()` function
- `IGNORE_INDEX` constant
The existing QwenVLTaskEncoder should then import from this common location (backward-compatible).
Step 4: Recipe Integration
Add a `dataset_type="energon"` option to each HF-encoder VL model recipe:
4a. Gemma3-VL Recipe
File: src/megatron/bridge/recipes/gemma3_vl/gemma3_vl.py
Add an `elif _dataset_choice == "energon":` branch that:

- Creates `AutoProcessor.from_pretrained(hf_path)`
- Instantiates `HFEncoderVLMTaskEncoder(processor=processor, tokenizer=tokenizer, seq_length=seq_length, visual_keys=["pixel_values"])`
- Creates `EnergonProvider(task_encoder=task_encoder, ...)`
4b. Ministral3 Recipe
File: src/megatron/bridge/recipes/ministral3/ministral3.py
Same pattern. Uses `PixtralProcessor` via `AutoProcessor.from_pretrained()`, with `visual_keys=["pixel_values"]`.
4c. GLM-4.5V Recipe
File: src/megatron/bridge/recipes/glm_vl/glm_45v.py
Same pattern. Uses the GLM processor, with `visual_keys=["pixel_values", "pixel_values_videos", "image_grid_thw", "video_grid_thw"]` (GLM needs grid info like Qwen).
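Across the three recipes, the only per-model variation is the `visual_keys` list. A small mapping (a hypothetical helper; the plan above inlines these lists in each recipe instead) keeps the variation in one place:

```python
# Which processor outputs each HF-encoder model forwards as visual inputs.
VISUAL_KEYS_BY_MODEL = {
    "gemma3_vl": ["pixel_values"],
    "ministral3": ["pixel_values"],
    # GLM-4.5V additionally needs grid THW info for MRoPE (hybrid).
    "glm_45v": ["pixel_values", "pixel_values_videos", "image_grid_thw", "video_grid_thw"],
}


def visual_keys_for(model: str) -> list:
    try:
        return VISUAL_KEYS_BY_MODEL[model]
    except KeyError:
        raise ValueError(f"No Energon support registered for model {model!r}") from None
```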
Step 5: Fix Ministral3 Collate (side fix)
File: src/megatron/bridge/data/vlm_datasets/collate.py
The existing `ministral3_collate_fn` does NOT wrap visual tensors in a `visual_inputs` container (unlike `qwen2_5_collate_fn` and `default_collate_fn`). This means `vlm_step.py` would receive `visual_inputs=None` and drop the pixel values. Add the `GenericVisualInputs` wrapper to `ministral3_collate_fn` for consistency.
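The fix amounts to popping the vision tensors out of the flat batch dict and attaching them under `visual_inputs`, mirroring `qwen2_5_collate_fn`. A simplified sketch (with a minimal stand-in dataclass in place of the real `GenericVisualInputs`, so the snippet is self-contained):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class VisualInputsStandIn:
    """Minimal stand-in for GenericVisualInputs, for illustration only."""
    pixel_values: Optional[Any] = None
    image_sizes: Optional[Any] = None


def wrap_visual_inputs(batch: dict) -> dict:
    """Move vision tensors under a 'visual_inputs' container so that
    vlm_step.py receives them instead of visual_inputs=None."""
    batch["visual_inputs"] = VisualInputsStandIn(
        pixel_values=batch.pop("pixel_values", None),
        image_sizes=batch.pop("image_sizes", None),
    )
    return batch
```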
Files Modified/Created
| File | Action | Description |
|---|---|---|
| `src/megatron/bridge/training/utils/visual_inputs.py` | Modify | Add `GenericVisualInputs` class |
| `src/megatron/bridge/data/energon/task_encoder_utils.py` | New | Shared utilities (`ChatMLSample`, `cook_chatml_sample`, etc.) |
| `src/megatron/bridge/data/energon/hf_encoder_task_encoder.py` | New | Generic HF-encoder task encoder |
| `src/megatron/bridge/data/energon/__init__.py` | Modify | Export new task encoder |
| `src/megatron/bridge/recipes/gemma3_vl/gemma3_vl.py` | Modify | Add energon `dataset_type` |
| `src/megatron/bridge/recipes/ministral3/ministral3.py` | Modify | Add energon `dataset_type` |
| `src/megatron/bridge/recipes/glm_vl/glm_45v.py` | Modify | Add energon `dataset_type` |
| `src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py` | Modify | Import shared code from utils |
| `src/megatron/bridge/data/vlm_datasets/collate.py` | Modify | Fix ministral3 `visual_inputs` wrapping |
Testing Plan
Unit Tests
File: tests/unit_tests/data/energon/test_hf_encoder_task_encoder.py (new)
- Test `GenericVisualInputs`:
  - `normalized_for_model()` returns only non-None fields
  - CUDA transfer works via `__dict__` iteration
  - Works with `pixel_values` only (Gemma3/Ministral3 case)
  - Works with `pixel_values` + `image_grid_thw` (GLM case)
- Test `cook_chatml_sample()` (shared utility):
  - Correctly parses pickle jpgs + JSON conversation
  - Handles missing images/videos gracefully
- Test `HFEncoderVLMTaskEncoder.encode_sample()`:
  - With mock Gemma3 processor: produces correct `input_ids`, `pixel_values`, `loss_mask`
  - With mock Pixtral processor: produces correct output
  - Loss mask correctly identifies assistant turns
  - Samples with no images produce `pixel_values=None`
  - Long sequences are truncated to `seq_length`
- Test `HFEncoderVLMTaskEncoder.batch()`:
  - Pads variable-length samples to a common max length
  - Correctly stacks visual tensors across samples
  - Labels have `IGNORE_INDEX` for masked positions
- Test `HFEncoderVLMTaskEncoder.encode_batch()`:
  - Output dict has a `visual_inputs` key with `GenericVisualInputs`
  - Output dict has `input_ids`, `labels`, `loss_mask`, `position_ids`, `attention_mask`
Integration Tests
File: tests/unit_tests/recipes/test_vlm_energon_recipes.py (new)
- Test recipe instantiation with `dataset_type="energon"`:
  - `gemma3_vl_*_config(dataset_type="energon", train_data_path=["/fake/path"])` creates a config with `EnergonProvider`
  - `ministral3_*_config(dataset_type="energon", ...)` creates a config with `EnergonProvider`
  - `glm_45v_*_config(dataset_type="energon", ...)` creates a config with `EnergonProvider`
  - Each provider has the correct task encoder type (`HFEncoderVLMTaskEncoder`)
- Test end-to-end data flow (mock data):
  - Create a mock Energon dataset with synthetic images + conversations
  - Verify the task encoder produces batches that match what `vlm_step.get_batch_from_iterator()` expects
  - Verify `visual_inputs.normalized_for_model()` returns the correct keys for each model type
Manual Validation (GPU required)
- Run Gemma3-VL training with mock data + energon:

  ```bash
  torchrun --nproc-per-node=1 examples/recipes/run_recipe.py \
      --recipe gemma3_vl_4b --dataset_type energon --train_data_path /path/to/energon/dataset
  ```

- Verify training loss decreases over 100 iterations
- Compare loss curves with the non-Energon (HF dataset) path to ensure equivalence