# Add ModalityTransform Methods in DataConfigs

## Description
Add `transform()` methods returning `ModalityTransform` pipelines to the DataConfigs. While applying no transforms is enough for basic inference, these methods are needed for subsequent training workflows.
## References
- https://github.com/NVIDIA/Isaac-GR00T/blob/main/getting_started/4_deeper_understanding.md#understanding-data-transforms
- https://github.com/NVIDIA/Isaac-GR00T/blob/main/gr00t/experiment/data_config.py
- https://github.com/NVIDIA/Isaac-GR00T/blob/main/gr00t/data/transform/base.py
- https://github.com/NVIDIA/Isaac-GR00T/blob/main/gr00t/data/transform/video.py
- https://github.com/NVIDIA/Isaac-GR00T/blob/main/gr00t/data/transform/state_action.py
- https://github.com/NVIDIA/Isaac-GR00T/blob/main/gr00t/data/transform/concat.py
- https://github.com/NVIDIA/Isaac-GR00T/blob/main/gr00t/model/transforms.py
## Current State (strands-robots)

Our `BaseDataConfig` in `strands_robots/policies/groot/data_config.py` only has:

- `modality_config()` method ✅

Missing:

- `transform()` method ❌
## NVIDIA Implementation Analysis

### 1. Base Transform Architecture
NVIDIA defines an abstract `ModalityTransform` class in `gr00t/data/transform/base.py`:

```python
class ModalityTransform(BaseModel, ABC):
    """Abstract class for transforming data modalities."""

    apply_to: list[str]  # Keys to apply the transform to
    training: bool = True  # Training vs. eval mode

    @abstractmethod
    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Apply transformation to data."""

    def train(self):
        self.training = True

    def eval(self):
        self.training = False


class InvertibleModalityTransform(ModalityTransform):
    @abstractmethod
    def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Reverse transformation."""


class ComposedModalityTransform(ModalityTransform):
    """Compose multiple transforms into a pipeline."""

    transforms: list[ModalityTransform]
```

### 2. Transform Pipeline Structure
NVIDIA's `BaseDataConfig.transform()` returns a `ComposedModalityTransform` with this pipeline:
```python
def transform(self) -> ModalityTransform:
    transforms = [
        # 1. Video transforms
        VideoToTensor(apply_to=self.video_keys),
        VideoCrop(apply_to=self.video_keys, scale=0.95),
        VideoResize(apply_to=self.video_keys, height=224, width=224),
        VideoColorJitter(
            apply_to=self.video_keys,
            brightness=0.3, contrast=0.4, saturation=0.5, hue=0.08,
        ),
        VideoToNumpy(apply_to=self.video_keys),
        # 2. State transforms
        StateActionToTensor(apply_to=self.state_keys),
        StateActionTransform(apply_to=self.state_keys, normalization_modes={...}),
        # 3. Action transforms
        StateActionToTensor(apply_to=self.action_keys),
        StateActionTransform(apply_to=self.action_keys, normalization_modes={...}),
        # 4. Concat transform
        ConcatTransform(video_concat_order=..., state_concat_order=..., action_concat_order=...),
        # 5. Model-specific transform
        GR00TTransform(state_horizon=..., action_horizon=..., max_state_dim=64, max_action_dim=32),
    ]
    return ComposedModalityTransform(transforms=transforms)
```

### 3. Key Transform Classes Needed
#### Video Transforms (`video.py`)

- `VideoToTensor`: numpy `[T, H, W, C]` uint8 → torch `[T, C, H, W]` float32
- `VideoCrop`: random/center crop with a scale factor
- `VideoResize`: resize to the target resolution (224x224)
- `VideoColorJitter`: augmentation (brightness, contrast, saturation, hue)
- `VideoToNumpy`: torch → numpy for inference
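The layout change performed by `VideoToTensor` can be sketched without torch. Note this is an illustration, not the actual class: the real transform produces torch tensors, while numpy stands in here to keep the example dependency-free.

```python
import numpy as np


def video_to_tensor_like(frames: np.ndarray) -> np.ndarray:
    """Sketch of VideoToTensor's layout change: [T, H, W, C] uint8 in [0, 255]
    -> [T, C, H, W] float32 in [0, 1]."""
    assert frames.dtype == np.uint8 and frames.ndim == 4
    return frames.transpose(0, 3, 1, 2).astype(np.float32) / 255.0


clip = np.random.randint(0, 256, size=(8, 480, 640, 3), dtype=np.uint8)  # T=8 frames
out = video_to_tensor_like(clip)  # shape (8, 3, 480, 640)
```

`VideoToNumpy` at the end of the pipeline is essentially the inverse layout step, handing numpy arrays back to inference code.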
#### State/Action Transforms (`state_action.py`)

- `StateActionToTensor`: numpy → torch
- `StateActionTransform`: normalization (min_max, mean_std, q99, binary) plus rotation conversion
- `StateActionSinCosTransform`: sin-cos encoding for joint angles
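A minimal sketch of the normalization modes, to make the behavior concrete. The exact output ranges and formulas are assumptions; check `gr00t/data/transform/state_action.py` for the real implementation (q99 and rotation conversion are omitted here).

```python
import numpy as np


def normalize(x: np.ndarray, mode: str, stats: dict) -> np.ndarray:
    """Illustrative StateActionTransform-style normalization."""
    if mode == "min_max":
        # Assumed convention: scale to [-1, 1] using per-dimension min/max
        return 2.0 * (x - stats["min"]) / (stats["max"] - stats["min"]) - 1.0
    if mode == "mean_std":
        return (x - stats["mean"]) / stats["std"]
    if mode == "binary":
        return (x > 0.5).astype(np.float32)  # e.g. gripper open/closed
    raise ValueError(f"unknown mode: {mode}")


stats = {"min": np.array([0.0]), "max": np.array([10.0])}
out = normalize(np.array([5.0]), "min_max", stats)  # midpoint maps to 0.0
```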
#### Concat Transform (`concat.py`)
- Concatenates video views along new axis
- Concatenates state/action keys into single tensors
- Tracks dimensions for unapply (splitting back)
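The dimension-tracking idea behind apply/unapply can be sketched as follows (the real `ConcatTransform` works on the pipeline's data dict and key ordering; names here are illustrative):

```python
import numpy as np


def concat_state(data: dict, state_keys: list[str]):
    """Concatenate state keys along the last axis, recording each key's
    width so the operation can be reversed."""
    dims = {k: data[k].shape[-1] for k in state_keys}
    merged = np.concatenate([data[k] for k in state_keys], axis=-1)
    return merged, dims


def split_state(merged: np.ndarray, dims: dict):
    """Inverse: split the concatenated array back into per-key arrays."""
    out, start = {}, 0
    for k, d in dims.items():
        out[k] = merged[..., start:start + d]
        start += d
    return out


data = {"state.arm": np.zeros((1, 7)), "state.gripper": np.ones((1, 1))}
merged, dims = concat_state(data, ["state.arm", "state.gripper"])  # (1, 8)
```

Tracking `dims` is what makes the transform invertible: `unapply` only needs the recorded widths to split model outputs back into named action keys.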
#### Model Transform (`transforms.py`)

- `GR00TTransform`: pads state/action to max dims and applies VLM processing
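The padding step can be sketched as below; this is an assumption-laden illustration (whether a validity mask is returned, and in what form, should be checked against `gr00t/model/transforms.py`), and the VLM processing step is omitted entirely.

```python
import numpy as np


def pad_to_max(x: np.ndarray, max_dim: int):
    """Zero-pad the feature axis to a fixed width (cf. max_state_dim=64,
    max_action_dim=32) and return a mask marking the real dimensions."""
    t, d = x.shape
    assert d <= max_dim, "state/action wider than the model's max dim"
    padded = np.zeros((t, max_dim), dtype=x.dtype)
    padded[:, :d] = x
    mask = np.zeros(max_dim, dtype=bool)
    mask[:d] = True
    return padded, mask


state = np.ones((1, 7), dtype=np.float32)  # e.g. a 7-DoF arm state
padded, mask = pad_to_max(state, 64)
```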
## Implementation Plan

### Phase 1: Base Transform Infrastructure

Create a `strands_robots/policies/groot/transforms/` directory:

```text
strands_robots/policies/groot/transforms/
├── __init__.py
├── base.py           # ModalityTransform, InvertibleModalityTransform, ComposedModalityTransform
├── video.py          # VideoToTensor, VideoCrop, VideoResize, VideoColorJitter, VideoToNumpy
├── state_action.py   # StateActionToTensor, StateActionTransform, StateActionSinCosTransform
└── concat.py         # ConcatTransform
```
### Phase 2: Update BaseDataConfig

Add an abstract `transform()` method to `BaseDataConfig`:
```python
@dataclass
class BaseDataConfig(ABC):
    # ... existing fields ...

    @abstractmethod
    def transform(self) -> ModalityTransform:
        """Return the transform pipeline for this data config."""
        pass
```

### Phase 3: Implement Transforms for Each Config
Update each concrete config (`So100DataConfig`, `FourierGr1ArmsOnlyDataConfig`, etc.) with its specific transform pipeline.
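The shape of a concrete config could look like the sketch below. The key names (`video.webcam`, `state.single_arm`) and the lambda "transforms" are illustrative stand-ins for the real keys and the Phase 1 transform classes, not the actual strands-robots code.

```python
from dataclasses import dataclass, field


class ComposedModalityTransform:
    """Minimal stand-in for the Phase 1 composed transform."""

    def __init__(self, transforms):
        self.transforms = transforms

    def apply(self, data):
        for t in self.transforms:
            data = t(data)
        return data


@dataclass
class So100DataConfig:
    """Each concrete config wires its own keys into its own pipeline."""

    video_keys: list = field(default_factory=lambda: ["video.webcam"])
    state_keys: list = field(default_factory=lambda: ["state.single_arm"])

    def transform(self) -> ComposedModalityTransform:
        # Lambdas stand in for VideoToTensor, StateActionToTensor, etc.
        return ComposedModalityTransform([
            lambda d: {**d, "n_video_keys": len(self.video_keys)},
            lambda d: {**d, "n_state_keys": len(self.state_keys)},
        ])


pipeline = So100DataConfig().transform()
out = pipeline.apply({})
```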
### Phase 4: Integration with Policy

Update `Gr00tPolicy` to optionally use transforms for training workflows.
## Simplified Initial Implementation (Inference-Only)
For inference-only use cases, we can start with a minimal implementation:
```python
class IdentityTransform(ModalityTransform):
    """No-op transform for inference."""

    def apply(self, data):
        return data


class BaseDataConfig:
    def transform(self) -> ModalityTransform:
        """Default: no transforms (identity)."""
        return IdentityTransform(apply_to=[])
```

## Full Implementation Priority
| Transform | Priority | Reason |
|---|---|---|
| Base classes | High | Foundation for all transforms |
| VideoToTensor/ToNumpy | High | Basic type conversion |
| StateActionToTensor | High | Basic type conversion |
| VideoCrop/Resize | Medium | Required for proper image preprocessing |
| StateActionTransform (normalization) | Medium | Required for training |
| VideoColorJitter | Low | Augmentation (training only) |
| ConcatTransform | Medium | Required for multi-modal data |
| GR00TTransform | Low | Complex, depends on Eagle VLM |
## Acceptance Criteria

- [ ] `ModalityTransform` base classes implemented
- [ ] `transform()` method added to `BaseDataConfig`
- [ ] Video transforms implemented (at least ToTensor, Resize, ToNumpy)
- [ ] State/action transforms implemented (at least ToTensor, basic normalization)
- [ ] Each concrete DataConfig has a working `transform()` method
- [ ] Unit tests for transforms
- [ ] Documentation with transform examples
- [ ] Integration tests with GR00T inference
## Code Changes Required

**File:** `strands_robots/policies/groot/data_config.py`

**Before:**
```python
@dataclass
class BaseDataConfig(ABC):
    video_keys: List[str]
    # ... other fields ...

    def modality_config(self) -> Dict[str, ModalityConfig]:
        # ... existing implementation ...
```

**After:**
```python
from .transforms import ModalityTransform, ComposedModalityTransform


@dataclass
class BaseDataConfig(ABC):
    video_keys: List[str]
    # ... other fields ...

    def modality_config(self) -> Dict[str, ModalityConfig]:
        # ... existing implementation ...

    @abstractmethod
    def transform(self) -> ModalityTransform:
        """Return the transform pipeline for training/inference."""
        pass
```

## Priority
Medium - Important for training workflows