Add ModalityTransform Methods in DataConfigs #1

@cagataycali

Description
Add transform() methods for ModalityTransform support in the DataConfigs. Applying no transforms remains a valid option for inference, but these methods are needed for subsequent training workflows.

References

Current State (strands-robots)

Our BaseDataConfig in strands_robots/policies/groot/data_config.py only has:

  • modality_config() method ✅
  • Missing: transform() method ❌

NVIDIA Implementation Analysis

1. Base Transform Architecture

NVIDIA has an abstract ModalityTransform class in gr00t/data/transform/base.py:

class ModalityTransform(BaseModel, ABC):
    """Abstract class for transforming data modalities."""
    apply_to: list[str]  # Keys to apply transform to
    training: bool = True  # Training vs eval mode
    
    @abstractmethod
    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Apply transformation to data."""
    
    def train(self): self.training = True
    def eval(self): self.training = False

class InvertibleModalityTransform(ModalityTransform):
    @abstractmethod
    def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Reverse transformation."""

class ComposedModalityTransform(ModalityTransform):
    """Compose multiple transforms into a pipeline."""
    transforms: list[ModalityTransform]
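A runnable sketch of how the composition above could behave. Note this is an illustration, not the NVIDIA code: the real classes are pydantic `BaseModel` subclasses, while this sketch uses plain `abc` to stay dependency-free.

```python
from abc import ABC, abstractmethod
from typing import Any


class ModalityTransform(ABC):
    """Minimal stand-in for the pydantic-based base class above."""

    def __init__(self, apply_to: list[str], training: bool = True):
        self.apply_to = apply_to
        self.training = training

    @abstractmethod
    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Apply transformation to data."""


class ComposedModalityTransform(ModalityTransform):
    """Chains transforms: each stage receives the previous stage's output."""

    def __init__(self, transforms: list[ModalityTransform]):
        super().__init__(apply_to=[])
        self.transforms = transforms

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        for t in self.transforms:
            data = t.apply(data)
        return data

    def train(self):  # propagate mode switches to every stage
        for t in self.transforms:
            t.training = True

    def eval(self):
        for t in self.transforms:
            t.training = False
```

The key property is that `train()`/`eval()` fan out to every stage, so augmentation transforms can check `self.training` and become no-ops at inference time.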

2. Transform Pipeline Structure

NVIDIA's BaseDataConfig.transform() returns a ComposedModalityTransform with this pipeline:

def transform(self) -> ModalityTransform:
    transforms = [
        # 1. VIDEO TRANSFORMS
        VideoToTensor(apply_to=self.video_keys),
        VideoCrop(apply_to=self.video_keys, scale=0.95),
        VideoResize(apply_to=self.video_keys, height=224, width=224),
        VideoColorJitter(apply_to=self.video_keys, brightness=0.3, contrast=0.4, saturation=0.5, hue=0.08),
        VideoToNumpy(apply_to=self.video_keys),
        
        # 2. STATE TRANSFORMS
        StateActionToTensor(apply_to=self.state_keys),
        StateActionTransform(apply_to=self.state_keys, normalization_modes={...}),
        
        # 3. ACTION TRANSFORMS
        StateActionToTensor(apply_to=self.action_keys),
        StateActionTransform(apply_to=self.action_keys, normalization_modes={...}),
        
        # 4. CONCAT TRANSFORM
        ConcatTransform(video_concat_order=..., state_concat_order=..., action_concat_order=...),
        
        # 5. MODEL-SPECIFIC TRANSFORM
        GR00TTransform(state_horizon=..., action_horizon=..., max_state_dim=64, max_action_dim=32),
    ]
    return ComposedModalityTransform(transforms=transforms)
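Each stage in the pipeline only touches the dict keys listed in its `apply_to`; everything else passes through untouched. A hypothetical `ScaleTransform` (not part of the NVIDIA API) illustrating that filtering pattern:

```python
from typing import Any


class ScaleTransform:
    """Hypothetical stage: multiplies only the keys listed in apply_to."""

    def __init__(self, apply_to: list[str], scale: float):
        self.apply_to = apply_to
        self.scale = scale
        self.training = True

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        for key in self.apply_to:
            if key in data:  # keys absent from the batch are skipped
                data[key] = data[key] * self.scale
        return data
```

This is why a single flat batch dict can flow through video, state, and action stages in sequence: each stage is scoped by its keys rather than by position.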

3. Key Transform Classes Needed

Video Transforms (video.py)

  • VideoToTensor: numpy [T,H,W,C] uint8 → torch [T,C,H,W] float32
  • VideoCrop: Random/center crop with scale factor
  • VideoResize: Resize to target resolution (224x224)
  • VideoColorJitter: Augmentation (brightness, contrast, saturation, hue)
  • VideoToNumpy: torch → numpy for inference
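The `VideoToTensor` layout/dtype change can be sketched with numpy alone (the real implementation returns a `torch.Tensor`; the `[0, 1]` scaling is an assumption to check against `gr00t/data/transform/video.py`):

```python
import numpy as np


def video_to_tensor_layout(frames: np.ndarray) -> np.ndarray:
    """[T, H, W, C] uint8 -> [T, C, H, W] float32 in [0, 1] (illustrative)."""
    assert frames.dtype == np.uint8 and frames.ndim == 4
    # Move channels before height/width, then rescale to floats.
    return frames.transpose(0, 3, 1, 2).astype(np.float32) / 255.0
```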

State/Action Transforms (state_action.py)

  • StateActionToTensor: numpy → torch
  • StateActionTransform: Normalization (min_max, mean_std, q99, binary) + Rotation conversion
  • StateActionSinCosTransform: Sin-cos encoding for joint angles
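Two of these operations are small enough to sketch directly. The `[-1, 1]` target range for `min_max` is an assumption to verify against `state_action.py`; the sin-cos encoding doubles the feature dimension so joint angles wrap continuously at ±π:

```python
import numpy as np


def min_max_normalize(x: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> np.ndarray:
    """Map per-dimension values from [lo, hi] into [-1, 1] (assumed range)."""
    return 2.0 * (x - lo) / (hi - lo) - 1.0


def sincos_encode(angles: np.ndarray) -> np.ndarray:
    """Sin-cos encoding of joint angles, doubling the last dimension."""
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
```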

Concat Transform (concat.py)

  • Concatenates video views along new axis
  • Concatenates state/action keys into single tensors
  • Tracks dimensions for unapply (splitting back)
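The dimension-tracking point is the subtle part: `unapply` must split the concatenated tensor back into named keys. A minimal sketch of that bookkeeping (class name and details are illustrative, not NVIDIA's):

```python
import numpy as np


class ConcatSketch:
    """Concatenate named keys along the last axis, remembering split sizes."""

    def __init__(self, order: list[str]):
        self.order = order
        self.dims: list[int] | None = None

    def apply(self, data: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        parts = [data[k] for k in self.order]
        self.dims = [p.shape[-1] for p in parts]  # remember for unapply
        return {"state": np.concatenate(parts, axis=-1)}

    def unapply(self, data: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        splits = np.split(data["state"], np.cumsum(self.dims)[:-1], axis=-1)
        return dict(zip(self.order, splits))
```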

Model Transform (transforms.py)

  • GR00TTransform: Pads state/action to max dims, applies VLM processing
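The padding half of `GR00TTransform` is straightforward to sketch (the VLM processing is not, and depends on Eagle). Zero-padding the feature dimension up to `max_state_dim=64` / `max_action_dim=32` could look like:

```python
import numpy as np


def pad_last_dim(x: np.ndarray, max_dim: int) -> np.ndarray:
    """Zero-pad the feature (last) dimension up to max_dim."""
    pad = max_dim - x.shape[-1]
    assert pad >= 0, f"feature dim {x.shape[-1]} exceeds max_dim {max_dim}"
    return np.pad(x, [(0, 0)] * (x.ndim - 1) + [(0, pad)])
```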

Implementation Plan

Phase 1: Base Transform Infrastructure

Create strands_robots/policies/groot/transforms/ directory:

strands_robots/policies/groot/transforms/
├── __init__.py
├── base.py           # ModalityTransform, InvertibleModalityTransform, ComposedModalityTransform
├── video.py          # VideoToTensor, VideoCrop, VideoResize, VideoColorJitter, VideoToNumpy
├── state_action.py   # StateActionToTensor, StateActionTransform, StateActionSinCosTransform
└── concat.py         # ConcatTransform

Phase 2: Update BaseDataConfig

Add abstract transform() method to BaseDataConfig:

@dataclass
class BaseDataConfig(ABC):
    # ... existing fields ...
    
    @abstractmethod
    def transform(self) -> ModalityTransform:
        """Return the transform pipeline for this data config."""
        pass

Phase 3: Implement Transforms for Each Config

Update each concrete config (So100DataConfig, FourierGr1ArmsOnlyDataConfig, etc.) with their specific transform pipelines.

Phase 4: Integration with Policy

Update Gr00tPolicy to optionally use transforms for training workflows.

Simplified Initial Implementation (Inference-Only)

For inference-only use cases, we can start with a minimal implementation:

class IdentityTransform(ModalityTransform):
    """No-op transform for inference."""

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        return data

class BaseDataConfig:
    def transform(self) -> ModalityTransform:
        """Default: no transforms (identity)."""
        return IdentityTransform(apply_to=[])
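With that default in place, downstream code never needs to special-case "no transforms". A self-contained sketch of the behavior (re-declaring minimal versions of both classes so the snippet runs on its own):

```python
from typing import Any


class IdentityTransform:
    """No-op transform: returns the batch unchanged."""

    def __init__(self, apply_to: list[str]):
        self.apply_to = apply_to

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        return data


class BaseDataConfig:
    def transform(self) -> IdentityTransform:
        """Default: no transforms (identity)."""
        return IdentityTransform(apply_to=[])


# Inference code can call transform() unconditionally:
batch = {"state.pos": [0.1, 0.2]}
processed = BaseDataConfig().transform().apply(batch)
```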

Full Implementation Priority

| Transform | Priority | Reason |
| --- | --- | --- |
| Base classes | High | Foundation for all transforms |
| VideoToTensor/ToNumpy | High | Basic type conversion |
| StateActionToTensor | High | Basic type conversion |
| VideoCrop/Resize | Medium | Required for proper image preprocessing |
| StateActionTransform (normalization) | Medium | Required for training |
| VideoColorJitter | Low | Augmentation (training only) |
| ConcatTransform | Medium | Required for multi-modal data |
| GR00TTransform | Low | Complex, depends on Eagle VLM |

Acceptance Criteria

  • ModalityTransform base classes implemented
  • transform() method added to BaseDataConfig
  • Video transforms implemented (at least ToTensor, Resize, ToNumpy)
  • State/Action transforms implemented (at least ToTensor, basic normalization)
  • Each concrete DataConfig has working transform() method
  • Unit tests for transforms
  • Documentation with transform examples
  • Integration tests with GR00T inference

Code Changes Required

File: strands_robots/policies/groot/data_config.py

Before:

@dataclass
class BaseDataConfig(ABC):
    video_keys: List[str]
    # ... other fields ...
    
    def modality_config(self) -> Dict[str, ModalityConfig]:
        # ... existing implementation ...

After:

from .transforms import ModalityTransform, ComposedModalityTransform

@dataclass  
class BaseDataConfig(ABC):
    video_keys: List[str]
    # ... other fields ...
    
    def modality_config(self) -> Dict[str, ModalityConfig]:
        # ... existing implementation ...
    
    @abstractmethod
    def transform(self) -> ModalityTransform:
        """Return the transform pipeline for training/inference."""
        pass

Priority

Medium - Important for training workflows
