Energon Dataloader for All HF-based encoders #2573

@cuichenx

Description

Plan: Energon Dataloader Support for HF-Encoder VL Models

Context

PR #2440 adds Energon dataloader support for Qwen3-VL, which uses a Megatron-native vision encoder. The QwenVLTaskEncoder is tightly coupled to Qwen's architecture (image token expansion based on grid THW, Qwen-specific process_vision, Qwen2_5_VLVisualInputs container).

HF-encoder VL models use HuggingFace vision towers instead of Megatron-native encoders. These models handle image processing through their HF processor (AutoProcessor) and don't need the Qwen-specific token expansion logic. The target models are:

| Model | Vision Encoder | HF Processor | Notes |
|---|---|---|---|
| Gemma3-VL | SiglipVisionModel (HF) | Gemma3Processor | Fixed 256 tokens/image, bidirectional attention mask |
| Ministral3 | PixtralVisionModel (HF) | PixtralProcessor | Variable patch sizes |
| GLM-4.5V | Glm4vVisionModel (HF) | Glm4vProcessor | Needs grid THW + MRoPE (hybrid) |

Goal: Create a generic Energon task encoder that works for all HF-encoder VL models by delegating vision processing to the model's HF processor, then integrate it into each model's recipe.


Step 1: Create Generic Visual Inputs Container

File: src/megatron/bridge/training/utils/visual_inputs.py (modify existing)

Add a GenericVisualInputs dataclass alongside the existing Qwen2_5_VLVisualInputs:

from dataclasses import dataclass, fields
from typing import Optional

import torch


@dataclass
class GenericVisualInputs:
    """Container for visual modality tensors from HF-encoder VL models.

    Holds the superset of visual fields needed across different HF-encoder models.
    Only non-None fields are passed to the model via normalized_for_model().
    """
    pixel_values: Optional[torch.Tensor] = None
    pixel_values_videos: Optional[torch.Tensor] = None
    image_grid_thw: Optional[torch.Tensor] = None
    video_grid_thw: Optional[torch.Tensor] = None
    image_sizes: Optional[torch.Tensor] = None

    def as_model_kwargs(self) -> dict[str, torch.Tensor]:
        """Return a mapping of non-None fields suitable for model forward kwargs."""
        return {f.name: getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None}

    def normalized_for_model(self) -> dict[str, torch.Tensor]:
        """Return non-None fields. No shape transformation needed for HF-encoder models."""
        return self.as_model_kwargs()

This is compatible with the existing vlm_step.py code, which iterates over val.__dict__ to move tensors to CUDA and calls visual_inputs.normalized_for_model() to obtain the model forward kwargs.
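As a quick illustration of the non-None filtering behavior, here is a minimal sketch that mirrors the dataclass above; plain lists stand in for torch.Tensor so it runs without a GPU or torch installed:

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class GenericVisualInputs:
    # Lists stand in for torch.Tensor here; the real class holds tensors.
    pixel_values: Optional[list] = None
    pixel_values_videos: Optional[list] = None
    image_grid_thw: Optional[list] = None
    video_grid_thw: Optional[list] = None
    image_sizes: Optional[list] = None

    def as_model_kwargs(self) -> dict:
        # Only fields that were actually set are forwarded to the model.
        return {f.name: getattr(self, f.name)
                for f in fields(self) if getattr(self, f.name) is not None}

    def normalized_for_model(self) -> dict:
        return self.as_model_kwargs()


# Gemma3/Ministral3 case: images only, no grid info.
vi = GenericVisualInputs(pixel_values=[[0.1, 0.2]])
print(sorted(vi.normalized_for_model()))  # ['pixel_values']

# GLM-4.5V case: images plus grid THW.
vi_glm = GenericVisualInputs(pixel_values=[[0.1]], image_grid_thw=[[1, 2, 2]])
print(sorted(vi_glm.normalized_for_model()))  # ['image_grid_thw', 'pixel_values']
```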


Step 2: Create Generic HF-Encoder Task Encoder

File: src/megatron/bridge/data/energon/hf_encoder_task_encoder.py (new file)

Design

The key insight: for HF-encoder models, the HF processor already handles both tokenization and image processing via apply_chat_template() and processor(). We don't need custom token expansion logic like Qwen3-VL. We just need to:

  1. Convert raw Energon samples into conversation format
  2. Run the HF processor to get input_ids, pixel_values, etc.
  3. Create loss masks (reuse the existing search-based pattern from collate.py)
  4. Package into batches with GenericVisualInputs

Class Structure

class HFEncoderVLMTaskEncoder(DefaultTaskEncoder[ChatMLSample, HFEncoderTaskSample, HFEncoderTaskBatch, dict]):

Reuses the existing ChatMLSample and cook_chatml_sample() from recipes/qwen_vl/data/energon/task_encoder.py (moved to a shared location in Step 3 so both encoders can import them).

Key Methods

__init__(processor, tokenizer, seq_length, visual_keys):

  • processor: HF AutoProcessor instance (handles images + text)
  • tokenizer: HF tokenizer (for loss mask creation)
  • seq_length: max sequence length
  • visual_keys: list of keys to extract from processor output as visual inputs (e.g., ["pixel_values"] for Gemma3/Ministral3, ["pixel_values", "image_grid_thw"] for GLM-4.5V)

encode_sample(sample: ChatMLSample) -> HFEncoderTaskSample:

  1. Parse conversation from sample.conversation (JSON string)
  2. Convert to HF conversation format (same logic as QwenVLTaskEncoder conversation parsing)
  3. Process images: convert sample.imgs (list of PIL images) to the format processor expects
  4. Run processor.apply_chat_template(conversation, tokenize=True, ...) to get input_ids
  5. Run processor(images=images, ...) to get pixel_values and other vision tensors
  6. Create loss mask using the search-based pattern from collate.py:create_multiturn_loss_mask_by_search()
  7. Return HFEncoderTaskSample with all fields
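The steps above can be sketched as follows. This is a simplified stand-in, not the real encoder: FakeProcessor mimics only the two processor calls the design relies on, the loss mask is a placeholder (the real encoder uses the search-based pattern from collate.py), and the returned dict stands in for HFEncoderTaskSample:

```python
import json

IGNORE_INDEX = -100  # label value for positions excluded from the loss


class FakeProcessor:
    """Stand-in mimicking the two HF processor calls the encoder relies on."""

    def apply_chat_template(self, conversation, tokenize=True):
        # Real processors render the chat template and tokenize it.
        text = " ".join(turn["content"] for turn in conversation)
        return [hash(w) % 1000 for w in text.split()]

    def __call__(self, images=None, **kwargs):
        # Real processors return pixel_values (plus grid tensors for GLM-4.5V).
        return {"pixel_values": [[0.0] * 4 for _ in (images or [])]}


def encode_sample(raw_conversation_json, images, processor, seq_length=8):
    conversation = json.loads(raw_conversation_json)
    # Step 4: tokenize via the chat template, truncating to seq_length.
    input_ids = processor.apply_chat_template(conversation, tokenize=True)[:seq_length]
    # Step 5: run the processor on the images (skipped for text-only samples).
    vision = processor(images=images) if images else {}
    # Step 6 (placeholder): the real code masks non-assistant turns by search.
    loss_mask = [1] * len(input_ids)
    return {"input_ids": input_ids, "loss_mask": loss_mask, **vision}


conv = json.dumps([{"role": "user", "content": "describe the image"},
                   {"role": "assistant", "content": "a cat"}])
sample = encode_sample(conv, images=["img0"], processor=FakeProcessor())
print(len(sample["pixel_values"]))  # 1
```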

batch(samples: List[HFEncoderTaskSample]) -> HFEncoderTaskBatch:

  1. Pad input_ids to max length in batch
  2. Stack/concatenate visual tensors (pixel_values, etc.)
  3. Create attention_mask, position_ids, loss_mask, labels
  4. Return HFEncoderTaskBatch
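The padding and stacking logic can be sketched with plain lists (the real code pads torch tensors and returns an HFEncoderTaskBatch; pad_id and the field names here are illustrative assumptions):

```python
IGNORE_INDEX = -100  # label value for padded/masked positions


def batch(samples, pad_id=0):
    max_len = max(len(s["input_ids"]) for s in samples)
    input_ids, loss_mask, labels = [], [], []
    for s in samples:
        pad = max_len - len(s["input_ids"])
        input_ids.append(s["input_ids"] + [pad_id] * pad)
        loss_mask.append(s["loss_mask"] + [0] * pad)
        # Padded positions get IGNORE_INDEX so they never contribute to the loss.
        labels.append(s["input_ids"] + [IGNORE_INDEX] * pad)
    # Visual tensors from all samples are concatenated along the first dim.
    pixel_values = [p for s in samples for p in s.get("pixel_values", [])]
    return {"input_ids": input_ids, "loss_mask": loss_mask,
            "labels": labels, "pixel_values": pixel_values}


b = batch([{"input_ids": [1, 2, 3], "loss_mask": [0, 1, 1], "pixel_values": [[0.1]]},
           {"input_ids": [4, 5], "loss_mask": [1, 1], "pixel_values": []}])
print([len(row) for row in b["input_ids"]])  # [3, 3]
```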

encode_batch(batch: HFEncoderTaskBatch) -> dict:

  1. Convert to dict
  2. Wrap visual tensors in GenericVisualInputs
  3. Return dict ready for vlm_step.py

Reused utilities (from existing code)

  • cook_chatml_sample() from recipes/qwen_vl/data/energon/task_encoder.py - sample parsing
  • create_multiturn_loss_mask_by_search() from data/vlm_datasets/collate.py - loss masking
  • extract_skipped_token_ids() from data/vlm_datasets/token_utils.py - pad token identification
  • get_ltor_masks_and_position_ids() from recipes/qwen_vl/data/energon/task_encoder.py - attention mask/position IDs
  • Cooker / basic_sample_keys from megatron.energon - sample cooking

Data format

Same pickle-based webdataset format as Qwen3-VL (jpgs pickle field + json conversation field). The cook_chatml_sample() function already handles this.


Step 3: Refactor Shared Code

Move shared code from the Qwen3-VL task encoder to a common location so both encoders can use it.

File: src/megatron/bridge/data/energon/task_encoder_utils.py (new file)

Move these from recipes/qwen_vl/data/energon/task_encoder.py:

  • ChatMLSample dataclass
  • cook_chatml_sample() function
  • get_ltor_masks_and_position_ids() function
  • find_pattern_indices() function
  • IGNORE_INDEX constant

The existing QwenVLTaskEncoder should then import from this common location (backward-compatible).


Step 4: Recipe Integration

Add dataset_type="energon" option to each HF-encoder VL model recipe:

4a. Gemma3-VL Recipe

File: src/megatron/bridge/recipes/gemma3_vl/gemma3_vl.py

Add an elif _dataset_choice == "energon": branch that:

  • Creates AutoProcessor.from_pretrained(hf_path)
  • Instantiates HFEncoderVLMTaskEncoder(processor=processor, tokenizer=tokenizer, seq_length=seq_length, visual_keys=["pixel_values"])
  • Creates EnergonProvider(task_encoder=task_encoder, ...)
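The branch could look roughly like the sketch below. The stub classes are stand-ins so the control flow runs without HF or Megatron installed; in the real recipe, FakeProcessor would be AutoProcessor.from_pretrained(hf_path) and the other two classes are the actual HFEncoderVLMTaskEncoder and EnergonProvider:

```python
class FakeProcessor:
    """Stand-in for AutoProcessor.from_pretrained(hf_path)."""


class HFEncoderVLMTaskEncoder:
    # Simplified stub; the real class subclasses DefaultTaskEncoder.
    def __init__(self, processor, tokenizer, seq_length, visual_keys):
        self.processor = processor
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.visual_keys = visual_keys


class EnergonProvider:
    def __init__(self, task_encoder):
        self.task_encoder = task_encoder


def build_dataset_provider(dataset_choice, tokenizer=None, seq_length=4096):
    if dataset_choice == "energon":
        processor = FakeProcessor()  # real code: AutoProcessor.from_pretrained(hf_path)
        task_encoder = HFEncoderVLMTaskEncoder(
            processor=processor, tokenizer=tokenizer,
            seq_length=seq_length, visual_keys=["pixel_values"])
        return EnergonProvider(task_encoder=task_encoder)
    raise ValueError(f"unknown dataset_choice: {dataset_choice}")


provider = build_dataset_provider("energon")
print(type(provider.task_encoder).__name__)  # HFEncoderVLMTaskEncoder
```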

4b. Ministral3 Recipe

File: src/megatron/bridge/recipes/ministral3/ministral3.py

Same pattern. Uses PixtralProcessor via AutoProcessor.from_pretrained().
visual_keys=["pixel_values"]

4c. GLM-4.5V Recipe

File: src/megatron/bridge/recipes/glm_vl/glm_45v.py

Same pattern. Uses GLM processor.
visual_keys=["pixel_values", "pixel_values_videos", "image_grid_thw", "video_grid_thw"] (GLM needs grid info like Qwen)


Step 5: Fix Ministral3 Collate (side fix)

File: src/megatron/bridge/data/vlm_datasets/collate.py

The existing ministral3_collate_fn does NOT wrap visual tensors in a visual_inputs container (unlike qwen2_5_collate_fn and default_collate_fn). This means vlm_step.py would receive visual_inputs=None and drop the pixel values. Add the GenericVisualInputs wrapper to ministral3_collate_fn for consistency.
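The shape of the fix can be sketched as follows, with a simplified one-field stand-in for GenericVisualInputs; the key/field names follow the plan above but the surrounding collate code is illustrative:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class GenericVisualInputs:
    # Simplified stand-in; the real container is defined in Step 1.
    pixel_values: Optional[list] = None


def ministral3_collate_fn(batch_dict):
    # Before the fix, pixel_values sat at the top level of the batch dict and
    # vlm_step.py saw visual_inputs=None, silently dropping the pixels.
    pv = batch_dict.pop("pixel_values", None)
    batch_dict["visual_inputs"] = GenericVisualInputs(pixel_values=pv)
    return batch_dict


out = ministral3_collate_fn({"input_ids": [[1, 2]], "pixel_values": [[0.5]]})
print(out["visual_inputs"].pixel_values)  # [[0.5]]
```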


Files Modified/Created

| File | Action | Description |
|---|---|---|
| src/megatron/bridge/training/utils/visual_inputs.py | Modify | Add GenericVisualInputs class |
| src/megatron/bridge/data/energon/task_encoder_utils.py | New | Shared utilities (ChatMLSample, cook_chatml_sample, etc.) |
| src/megatron/bridge/data/energon/hf_encoder_task_encoder.py | New | Generic HF-encoder task encoder |
| src/megatron/bridge/data/energon/__init__.py | Modify | Export new task encoder |
| src/megatron/bridge/recipes/gemma3_vl/gemma3_vl.py | Modify | Add energon dataset_type |
| src/megatron/bridge/recipes/ministral3/ministral3.py | Modify | Add energon dataset_type |
| src/megatron/bridge/recipes/glm_vl/glm_45v.py | Modify | Add energon dataset_type |
| src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py | Modify | Import shared code from utils |
| src/megatron/bridge/data/vlm_datasets/collate.py | Modify | Fix ministral3 visual_inputs wrapping |

Testing Plan

Unit Tests

File: tests/unit_tests/data/energon/test_hf_encoder_task_encoder.py (new)

  1. Test GenericVisualInputs:

    • normalized_for_model() returns only non-None fields
    • CUDA transfer works via __dict__ iteration
    • Works with pixel_values only (Gemma3/Ministral3 case)
    • Works with pixel_values + image_grid_thw (GLM case)
  2. Test cook_chatml_sample() (shared utility):

    • Correctly parses pickle jpgs + JSON conversation
    • Handles missing images/videos gracefully
  3. Test HFEncoderVLMTaskEncoder.encode_sample():

    • With mock Gemma3 processor: produces correct input_ids, pixel_values, loss_mask
    • With mock Pixtral processor: produces correct output
    • Loss mask correctly identifies assistant turns
    • Samples with no images produce pixel_values=None
    • Long sequences are truncated to seq_length
  4. Test HFEncoderVLMTaskEncoder.batch():

    • Pads variable-length samples to common max length
    • Correctly stacks visual tensors across samples
    • Labels have IGNORE_INDEX for masked positions
  5. Test HFEncoderVLMTaskEncoder.encode_batch():

    • Output dict has visual_inputs key with GenericVisualInputs
    • Output dict has input_ids, labels, loss_mask, position_ids, attention_mask

Integration Tests

File: tests/unit_tests/recipes/test_vlm_energon_recipes.py (new)

  1. Test recipe instantiation with dataset_type="energon":

    • gemma3_vl_*_config(dataset_type="energon", train_data_path=["/fake/path"]) creates config with EnergonProvider
    • ministral3_*_config(dataset_type="energon", ...) creates config with EnergonProvider
    • glm_45v_*_config(dataset_type="energon", ...) creates config with EnergonProvider
    • Each provider has the correct task encoder type (HFEncoderVLMTaskEncoder)
  2. Test end-to-end data flow (mock data):

    • Create a mock Energon dataset with synthetic images + conversations
    • Verify the task encoder produces batches that match what vlm_step.get_batch_from_iterator() expects
    • Verify visual_inputs.normalized_for_model() returns the correct keys for each model type

Manual Validation (GPU required)

  1. Run Gemma3-VL training with mock data + energon:

    torchrun --nproc-per-node=1 examples/recipes/run_recipe.py \
      --recipe gemma3_vl_4b --dataset_type energon --train_data_path /path/to/energon/dataset
  2. Verify training loss decreases over 100 iterations

  3. Compare loss curves with non-Energon (HF dataset) path to ensure equivalence

Metadata

Labels: enhancement (New feature or request)