Skip to content

Alignment #20

@bl-wang1

Description

@bl-wang1

Dear authors,
Thank you for providing the source code of realtime-vla.

My converted checkpoint does not produce the same actions as the original checkpoint. I also saw issue #19, which describes a similar problem, but the script from issue #19 still doesn't work for me.

Triton actions range: [-0.992531, 1.050355]
JAX actions range: [-1.000153, 1.015160]
Triton vs JAX MAE: 0.110080
Per-dimension MAE:
Joint 1: 0.204920
Joint 2: 0.090841
Joint 3: 0.258178
Joint 4: 0.059499
Joint 5: 0.106889
Joint 6: 0.029172
Gripper Width: 0.021063

I used the pi0_libero checkpoint:

checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi0_libero")

Then I converted it to the Triton-based format:

python convert_from_jax.py --jax_path ~/.cache/openpi/openpi-assets/checkpoints/pi0_libero --output ~/workspace/ckpt/pi0_libero.pkl --prompt "Sort bowls and paper cups into their designated places" --tokenizer_path ~/workspace/ckpt/pretrain/paligemma-3b-pt-224

Test script:

import numpy as np
import torch
import argparse
from PIL import Image
import einops
import json
import cv2
import os
from dexmal.pi0_infer import Pi0Inference
from openpi.training import config as _config
from openpi.policies import policy_config as _policy_config
import pickle

class Pi0ModelEvaluator:
    """Run one pi0 checkpoint (converted Triton or original JAX/openpi) on a single observation.

    The Triton path re-implements the pre/post-processing by hand: state padding and
    z-score normalization, image resize/pad/rescale, action unnormalization, and
    delta-to-absolute conversion. Its actions can then be compared with those returned by
    the JAX policy's ``infer``.
    """

    def __init__(self, task, model_type: str, triton_path: str, jax_path: str, norm_stats_dir: str, config_name: str):
        """Store the configuration and eagerly load the requested backend.

        Args:
            task: Free-form task label (stored only).
            model_type: ``"triton"`` or ``"jax"``.
            triton_path: Path to the pickled, converted Triton checkpoint.
            jax_path: Checkpoint location handed to openpi's ``create_trained_policy``.
            norm_stats_dir: Directory containing ``norm_stats.json`` (used by the Triton path).
            config_name: openpi config name used to build the JAX policy.

        Raises:
            ValueError: If ``model_type`` is neither ``"triton"`` nor ``"jax"``.
        """
        if model_type not in ("triton", "jax"):
            raise ValueError(f"Unknown model_type {model_type!r}; expected 'triton' or 'jax'")
        self.task = task
        self.model_type = model_type
        # open() and os.path.exists() do not expand "~", so defaults such as
        # "~/.cache/..." would otherwise fail to resolve.
        self.triton_path = os.path.expanduser(triton_path) if triton_path else triton_path
        self.jax_path = os.path.expanduser(jax_path) if jax_path else jax_path
        self.norm_stats_dir = os.path.expanduser(norm_stats_dir) if norm_stats_dir else norm_stats_dir
        self.config_name = config_name
        self.policy = None
        self.norm_stats = None
        self.results = {
            'episode_results': []
        }
        if self.model_type == "triton":
            self.policy, self.norm_stats = self._load_model(model_type=self.model_type)
        else:
            self.policy = self._load_jax_model()

    def _load_jax_model(self):
        """Build the openpi policy for ``self.config_name`` from ``self.jax_path``."""
        config = _config.get_config(self.config_name)
        return _policy_config.create_trained_policy(config, self.jax_path)

    def _load_model(self, model_type="triton"):
        """Load the converted Triton checkpoint and its normalization statistics.

        ``model_type`` is accepted for backward compatibility and is not used.

        Returns:
            ``(policy, norm_stats)``; ``norm_stats`` is ``None`` when no directory was given.
        """
        # NOTE: pickle.load can execute arbitrary code -- only load checkpoints you trust.
        with open(self.triton_path, 'rb') as f:
            weights = pickle.load(f)
        policy = Pi0Inference(checkpoint=weights, num_views=2, chunk_size=50)
        norm_stats = self._load_norm_stats(self.norm_stats_dir) if self.norm_stats_dir else None
        return policy, norm_stats

    def _load_norm_stats(self, norm_stats_dir: str) -> dict:
        """Return the ``"norm_stats"`` mapping from ``<norm_stats_dir>/norm_stats.json``.

        Raises:
            FileNotFoundError: If the JSON file does not exist. Returning ``None`` here
                used to surface much later as an opaque ``AttributeError`` in ``infer``.
        """
        norm_stats_path = os.path.join(os.path.expanduser(norm_stats_dir), "norm_stats.json")
        if not os.path.exists(norm_stats_path):
            raise FileNotFoundError(f"norm_stats.json not found at {norm_stats_path}")
        with open(norm_stats_path, 'r', encoding='utf-8') as f:
            return json.load(f)['norm_stats']

    def _parse_image(self, image) -> np.ndarray:
        """Return ``image`` as uint8 HWC.

        Float images are scaled by 255. An array whose first axis has size 3 is treated
        as CHW and moved to HWC.
        """
        image = np.asarray(image)
        if np.issubdtype(image.dtype, np.floating):
            image = (255 * image).astype(np.uint8)
        if image.shape[0] == 3:
            image = einops.rearrange(image, "c h w -> h w c")
        return image

    def _pad_to_dim(self, x: np.ndarray, target_dim: int, axis: int = -1) -> np.ndarray:
        """Zero-pad ``x`` along ``axis`` up to ``target_dim``; longer inputs are returned as-is."""
        current_dim = x.shape[axis]
        if current_dim < target_dim:
            pad_width = [(0, 0)] * len(x.shape)
            pad_width[axis] = (0, target_dim - current_dim)
            return np.pad(x, pad_width)
        return x

    def _resize_with_pad(self, image: np.ndarray, height: int = 224, width: int = 224) -> np.ndarray:
        """Fit ``image`` inside ``(height, width)`` keeping aspect ratio, then center it on zeros.

        Resizing uses PIL bilinear resampling. An image that already has the target size
        is returned unchanged.
        """
        pil_image = Image.fromarray(image)
        cur_width, cur_height = pil_image.size
        if cur_width == width and cur_height == height:
            return image

        ratio = max(cur_width / width, cur_height / height)
        resized_height = int(cur_height / ratio)
        resized_width = int(cur_width / ratio)
        resized_image = pil_image.resize((resized_width, resized_height), resample=Image.BILINEAR)
        zero_image = Image.new(resized_image.mode, (width, height), 0)
        pad_height = max(0, int((height - resized_height) / 2))
        pad_width = max(0, int((width - resized_width) / 2))
        zero_image.paste(resized_image, (pad_width, pad_height))
        return np.array(zero_image)

    def _normalize_image(self, image: np.ndarray) -> np.ndarray:
        """Map pixel values from [0, 255] to float32 [-1, 1]."""
        return image.astype(np.float32) / 255.0 * 2.0 - 1.0

    def _stats_for(self, norm_stats: dict, key: str, dim: int):
        """Return ``(mean, std)`` for ``norm_stats[key]``, zero-padded or truncated to ``dim``.

        Raises:
            ValueError: If ``norm_stats`` is empty/None or lacks ``key``.
        """
        if not norm_stats or key not in norm_stats:
            raise ValueError(f"norm_stats has no {key!r} entry; check norm_stats_dir")
        mean = self._pad_to_dim(np.asarray(norm_stats[key]["mean"]), dim)[..., :dim]
        std = self._pad_to_dim(np.asarray(norm_stats[key]["std"]), dim)[..., :dim]
        return mean, std

    def _normalize_state(self, state: np.ndarray, norm_stats: dict) -> np.ndarray:
        """Z-score normalize ``state`` with ``norm_stats["state"]``.

        The statistics are sized to ``state.shape[-1]`` instead of a fixed 32, so the
        ``action_dim`` chosen by the caller is respected. Zero-padded dimensions (mean 0,
        std 0) map to 0.

        Raises:
            ValueError: If no ``"state"`` statistics are available.
        """
        mean, std = self._stats_for(norm_stats, "state", state.shape[-1])
        return (state - mean) / (std + 1e-6)

    def _unnormalize_state(self, actions: np.ndarray, norm_stats: dict) -> np.ndarray:
        """Invert z-score normalization of ``actions`` using ``norm_stats["actions"]``.

        Despite its name, this operates on actions. The statistics are sized to
        ``actions.shape[-1]``.

        Raises:
            ValueError: If no ``"actions"`` statistics are available.
        """
        mean, std = self._stats_for(norm_stats, "actions", actions.shape[-1])
        return actions * (std + 1e-6) + mean

    def _apply_input_transforms(self, data: dict, action_dim: int = 32, norm_stats: dict = None) -> dict:
        """Build the normalized state / image dicts the Triton path feeds to the model.

        Args:
            data: Needs ``"state"``, ``"base_0_rgb"``, ``"left_wrist_0_rgb"``,
                ``"right_wrist_0_rgb"``; ``"prompt"`` is copied through when present.
            action_dim: Width the state is zero-padded to before normalization.
            norm_stats: Statistics mapping with a ``"state"`` entry.
        """
        state = self._pad_to_dim(data["state"], action_dim)
        state = self._normalize_state(state, norm_stats)
        image_dict = {
            view: self._normalize_image(self._resize_with_pad(self._parse_image(data[view]), 224, 224))
            for view in ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
        }
        # infer() feeds the first two views (base + left wrist) to the Triton model, so
        # both are marked valid; the right wrist view is not fed.
        image_mask_dict = {
            "base_0_rgb": np.True_,
            "left_wrist_0_rgb": np.True_,
            "right_wrist_0_rgb": np.False_,
        }
        inputs = {
            "state": state,
            "image": image_dict,
            "image_mask": image_mask_dict,
        }
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]
        return inputs

    def _infer_triton(self, inputs: dict, noise: np.ndarray) -> dict:
        """Run the Triton policy with the hand-written pre/post-processing above."""
        transformed = self._apply_input_transforms(inputs, action_dim=32, norm_stats=self.norm_stats)
        # Pi0Inference is constructed with num_views=2: only base + left wrist are fed.
        views = [torch.from_numpy(transformed["image"][name]) for name in ("base_0_rgb", "left_wrist_0_rgb")]
        observation_images = torch.stack(views, dim=0).to(torch.float32).cuda()
        observation_state = torch.from_numpy(transformed["state"].astype(np.float32)).cuda()
        diffusion_noise = torch.from_numpy(np.asarray(noise)).to(torch.float32).cuda()

        buffers = self.policy.buffers
        buffers['observation_images_normalized'].copy_(observation_images)
        buffers['observation_state_normalized'].copy_(observation_state)
        buffers['diffusion_noise'].copy_(diffusion_noise)
        self.policy.record_run()
        # The action chunk is read back from the noise buffer after record_run();
        # presumably the kernels denoise it in place -- TODO confirm against Pi0Inference.
        actions = buffers['diffusion_noise'].cpu().float().numpy()
        actions = self._unnormalize_state(actions, self.norm_stats)[:, :7]
        # Dims 0-5 get the raw state added (delta -> absolute); dim 6 is left as predicted.
        # NOTE(review): assumes the checkpoint was trained with delta actions on dims 0-5;
        # confirm against the openpi config used for the JAX comparison.
        raw_state = np.asarray(inputs["state"])
        actions[..., :6] = actions[..., :6] + raw_state[..., :6]
        return {"actions": actions}

    def infer(self, inputs: dict, noise: np.ndarray) -> dict:
        """Return ``{"actions": ...}`` for one observation using the given diffusion noise.

        Args:
            inputs: For ``"triton"``, the keys required by ``_apply_input_transforms``.
                For ``"jax"``, passed unchanged to the openpi policy's ``infer``.
            noise: Initial diffusion noise, shared by both backends for a fair comparison.

        Raises:
            ValueError: If ``self.model_type`` is not a supported backend.
        """
        if self.model_type == "triton":
            return self._infer_triton(inputs, noise)
        if self.model_type == "jax":
            return self.policy.infer(inputs, noise=noise)
        raise ValueError(f"Unknown model_type {self.model_type!r}")

def _load_rgb(path: str) -> np.ndarray:
    """Read an image file with OpenCV and return it as an RGB uint8 array.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``path``.
    """
    image = cv2.imread(os.path.expanduser(path))
    # cv2.imread returns None instead of raising when a file is missing or unreadable.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    # OpenCV decodes to BGR; the policies presumably expect RGB frames.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def main():
    """Run the same observation and noise through the Triton and JAX pi0 checkpoints.

    Prints the action ranges, the overall and per-dimension mean absolute error, and
    the first three actions from each backend.
    """
    parser = argparse.ArgumentParser(description="VLA Model Inference")
    parser.add_argument('--triton_path', type=str, default='~/workspace/ckpt/pi0_libero.pkl')
    parser.add_argument('--jax_path', type=str, default='~/.cache/openpi/openpi-assets/checkpoints/pi0_libero')
    parser.add_argument('--norm_stats_dir', type=str, default='~/.cache/openpi/openpi-assets/checkpoints/pi0_libero/assets/physical-intelligence/libero/')
    parser.add_argument('--config_name', type=str, default='pi0_libero')
    parser.add_argument('--base_image', type=str, default='~/workspace/data/1.png')
    parser.add_argument('--wrist_image', type=str, default='~/workspace/data/2.png')
    args = parser.parse_args()

    # open(), os.path.exists() and cv2.imread() do not expand "~"; strip() removes stray
    # whitespace such as the trailing space the old --triton_path default carried.
    triton_path = os.path.expanduser(args.triton_path.strip())
    jax_path = os.path.expanduser(args.jax_path.strip())
    norm_stats_dir = os.path.expanduser(args.norm_stats_dir.strip())

    example_image_global = _load_rgb(args.base_image)
    example_image_hand = _load_rgb(args.wrist_image)

    # One 32-dim noise row, stored as bfloat16 and up-cast to float32, repeated for all
    # 50 steps of the action chunk. Both backends receive this identical array.
    noise_row = torch.tensor([[-1.2578, -0.4023, -1.1250,  0.8789, -0.8633,  0.3457,  0.9414, -0.1226,
      -1.1875,  0.0713, -1.7578, -1.0234,  1.0312, -0.0031,  0.0167,  0.0422,
      -2.5938,  0.0087,  0.8398, -0.0830,  1.0156, -0.1553, -0.0806,  0.2988,
       1.0312, -1.3281, -0.9062,  0.2754,  0.9336, -0.0457,  0.0757, -0.0317]], dtype=torch.bfloat16)
    noise = np.array(noise_row.expand(50, -1).float().numpy(), dtype=np.float32)

    state = np.array([-0.2355, 0.1189, 1.1514, 3.1155, 0.0199, -0.1378, 0.0395, -0.0393], dtype=np.float32)
    prompt = "Sort bowls and paper cups into their designated places"

    inputs_triton = {
        "base_0_rgb": example_image_global,
        "left_wrist_0_rgb": example_image_hand,
        "right_wrist_0_rgb": example_image_hand,
        "state": state,
        "prompt": prompt,
    }
    inputs_jax = {
        "observation/image": example_image_global,
        "observation/wrist_image": example_image_hand,
        "observation/state": state,
        "prompt": prompt,
    }

    # Load and release one backend at a time to limit peak GPU memory.
    pi0_triton = Pi0ModelEvaluator(task='check_consistency', model_type="triton", triton_path=triton_path,
                                   jax_path=jax_path, norm_stats_dir=norm_stats_dir, config_name=args.config_name)
    result_triton = pi0_triton.infer(inputs_triton, noise)
    del pi0_triton
    print("finish triton")

    pi0_jax = Pi0ModelEvaluator(task='check_consistency', model_type="jax", triton_path=triton_path,
                                jax_path=jax_path, norm_stats_dir=norm_stats_dir, config_name=args.config_name)
    result_jax = pi0_jax.infer(inputs_jax, noise)
    del pi0_jax
    print("finish jax")

    triton_actions = np.asarray(result_triton['actions'])
    jax_actions = np.asarray(result_jax['actions'])
    print(f"Triton actions range: [{triton_actions.min():.6f}, {triton_actions.max():.6f}]")
    print(f"JAX actions range: [{jax_actions.min():.6f}, {jax_actions.max():.6f}]")
    # Without this check, NumPy broadcasting could silently produce a meaningless MAE.
    if triton_actions.shape != jax_actions.shape:
        raise ValueError(f"Action shape mismatch: triton {triton_actions.shape} vs jax {jax_actions.shape}")

    joint_names = ['Joint 1', 'Joint 2', 'Joint 3', 'Joint 4', 'Joint 5', 'Joint 6', 'Gripper Width']
    triton_jax_mae = np.mean(np.abs(triton_actions - jax_actions))
    print(f"Triton vs JAX MAE: {triton_jax_mae:.6f}")
    print("Per-dimension MAE:")
    for i, name in enumerate(joint_names):
        dim_mae = np.mean(np.abs(triton_actions[:, i] - jax_actions[:, i]))
        print(f"{name}: {dim_mae:.6f}")

    print("Triton actions:", triton_actions[:3])
    print("JAX actions:", jax_actions[:3])


if __name__ == "__main__":
    main()

In addition, I added two lines to `sample_actions` in pi0.py that cast the externally supplied noise to the same dtype as JAX-generated noise:

    @override
    def sample_actions(
        self,
        rng: at.KeyArrayLike,
        observation: _model.Observation,
        *,
        num_steps: int | at.Int[at.Array, ""] = 10,
        noise: at.Float[at.Array, "b ah ad"] | None = None,
    ) -> _model.Actions:
        observation = _model.preprocess_observation(None, observation, train=False)
        # note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target
        # distribution. yes, this is the opposite of the pi0 paper, and I'm sorry.
        dt = -1.0 / num_steps
        batch_size = observation.state.shape[0]
        if noise is None:
            noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))
        test_noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))
        noise = jnp.asarray(noise, dtype=test_noise.dtype)

The library versions (and hardware) used in the comparison are:
torch: 2.8.0
triton: 3.4.0
CUDA: 12.6
GPU: RTX 5090

Is there a way to pinpoint where the discrepancy is introduced? Any guidance would be greatly appreciated.
Thanks

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions