Skip to content

Pi05 libero alignment #31

@gitarya

Description

@gitarya

Hello:
Thanks for your great work!
I have performed a precision alignment test between the JAX (Reference) and Triton (Inference) backends for the pi0.5 model on the LIBERO dataset. While the inference latency is optimized in Triton (~43ms), there is a noticeable Mean Absolute Error (MAE) in the action space, particularly concentrated in Dimension 1.

This discrepancy likely explains the performance drop observed when deploying the model via Triton-based inference engines (e.g., LeRobot) compared to the official JAX results.

=== PROMPT: pick up the black bowl in the top drawer of the wooden cabinet and place it on the plate ===
Triton GPU  time: 42.81 ms (CUDA Event, avg 20 runs)
Triton Wall time: 42.76 ms (total end-to-end, avg 20 runs)
Triton actions range: [-0.469692, 1.002585]
JAX    actions range: [-0.470159, 1.003163]
Triton vs JAX MAE: 0.002777
Per-dimension MAE:
  dim 0: 0.002738
  dim 1: 0.010315
  dim 2: 0.002292
  dim 3: 0.000328
  dim 4: 0.001432
  dim 5: 0.000470
  dim 6: 0.001863
=== PROMPT: open the top drawer of the wooden cabinet ===
Triton GPU  time: 43.32 ms (CUDA Event, avg 20 runs)
Triton Wall time: 43.30 ms (total end-to-end, avg 20 runs)
Triton actions range: [-0.534057, 0.998606]
JAX    actions range: [-0.533780, 1.001180]
Triton vs JAX MAE: 0.002332
Per-dimension MAE:
  dim 0: 0.004500
  dim 1: 0.004662
  dim 2: 0.001603
  dim 3: 0.000440
  dim 4: 0.001742
  dim 5: 0.000989
  dim 6: 0.002390
=== PROMPT: put the red mug on the coffee machine ===
Triton GPU  time: 43.29 ms (CUDA Event, avg 20 runs)
Triton Wall time: 43.27 ms (total end-to-end, avg 20 runs)
Triton actions range: [-0.355880, 0.999427]
JAX    actions range: [-0.352955, 1.000402]
Triton vs JAX MAE: 0.003484
Per-dimension MAE:
  dim 0: 0.002887
  dim 1: 0.014394
  dim 2: 0.002141
  dim 3: 0.000637
  dim 4: 0.001431
  dim 5: 0.000826
  dim 6: 0.002071
=== PROMPT: pick up the fork and place it on the plate ===
Triton GPU  time: 43.41 ms (CUDA Event, avg 20 runs)
Triton Wall time: 43.63 ms (total end-to-end, avg 20 runs)
Triton actions range: [-0.497158, 0.997394]
JAX    actions range: [-0.498947, 0.997480]
Triton vs JAX MAE: 0.002958
Per-dimension MAE:
  dim 0: 0.003967
  dim 1: 0.010362
  dim 2: 0.002003
  dim 3: 0.000344
  dim 4: 0.001341
  dim 5: 0.000599
  dim 6: 0.002093
=== PROMPT: close the top drawer of the wooden cabinet ===
Triton GPU  time: 43.43 ms (CUDA Event, avg 20 runs)
Triton Wall time: 43.43 ms (total end-to-end, avg 20 runs)
Triton actions range: [-0.380434, 0.995849]
JAX    actions range: [-0.382129, 0.998579]
Triton vs JAX MAE: 0.001578
Per-dimension MAE:
  dim 0: 0.003547
  dim 1: 0.002827
  dim 2: 0.001437
  dim 3: 0.000338
  dim 4: 0.000293
  dim 5: 0.000294
  dim 6: 0.002311

This precision discrepancy leads to an approximately 2% drop in success rate on the LIBERO benchmark.

  Libero_Spatial Libero_10 Libero_Obj Libero_Goal
JAX 98.8% 92.4% 98.2% 98%
Triton 97% 90% 97.6% 96%


Is this result correct? How can I improve the precision?
Here is my code:

import numpy as np
import torch
import argparse
from PIL import Image
import einops
import json
import cv2
import os
import pickle

from pi05_infer import Pi05Inference
from openpi.training import config as _config
from openpi.policies import policy_config as _policy_config


class Pi05ModelEvaluator:
    """Run pi0.5 inference through either the Triton or the JAX backend.

    Both backends can be fed identical inputs (images, state, prompt,
    diffusion noise) so their output actions can be compared numerically.
    The Triton path performs its own preprocessing here (resize-with-pad,
    [-1, 1] image normalization, quantile un-normalization of actions),
    while the JAX path delegates all preprocessing to the openpi policy.
    """

    def __init__(self, task, model_type: str, triton_path: str, jax_path: str,
                 norm_stats_dir: str, config_name: str, prompt: str = "do something",
                 tokenizer_path: str = None, action_dim: int = 7):
        """Store configuration and load the requested backend.

        Args:
            task: Free-form task label; stored but not interpreted here.
            model_type: Backend selector, ``"triton"`` or ``"jax"``.
            triton_path: Path to the pickled Triton weight checkpoint.
            jax_path: Path to the openpi/JAX checkpoint.
            norm_stats_dir: Directory holding ``norm_stats.json`` (Triton path only).
            config_name: openpi training-config name (JAX path only).
            prompt: Language instruction forwarded to the model.
            tokenizer_path: Tokenizer location for the Triton backend.
            action_dim: Number of leading action dimensions kept by :meth:`infer`.
        """
        self.triton_path = triton_path
        self.jax_path = jax_path
        self.task = task
        self.model_type = model_type
        self.norm_stats_dir = norm_stats_dir
        self.config_name = config_name
        self.prompt = prompt
        self.tokenizer_path = tokenizer_path
        self.action_dim = action_dim
        self.policy = None
        self.norm_stats = None

        # Only the Triton path loads norm stats; the JAX policy applies its
        # own normalization internally. Any other model_type leaves
        # self.policy as None and infer() will raise for it.
        if self.model_type == "triton":
            self.policy, self.norm_stats = self._load_triton_model()
        elif self.model_type == "jax":
            self.policy = self._load_jax_model()

        # q01/q99 quantiles for state and actions (None in JAX mode, where
        # the policy handles normalization itself).
        self._digitize_bins = np.linspace(-1, 1, 256 + 1)[:-1]  # NOTE(review): unused in this class; kept for external callers/discrete-state mode — confirm
        self._state_q01 = np.array(self.norm_stats["state"]["q01"]) if self.norm_stats else None
        self._state_q99 = np.array(self.norm_stats["state"]["q99"]) if self.norm_stats else None
        self._actions_q01 = np.array(self.norm_stats["actions"]["q01"]) if self.norm_stats else None
        self._actions_q99 = np.array(self.norm_stats["actions"]["q99"]) if self.norm_stats else None

    def _load_jax_model(self):
        """Build the openpi trained policy from the named config and checkpoint."""
        config = _config.get_config(self.config_name)
        policy = _policy_config.create_trained_policy(config, self.jax_path)
        return policy

    def _load_triton_model(self):
        """Unpickle Triton weights and construct the Pi05Inference engine.

        Returns:
            Tuple of (policy, norm_stats dict or None).
        """
        # NOTE(review): pickle.load on an arbitrary checkpoint executes
        # embedded code — only load trusted files.
        with open(self.triton_path, 'rb') as f:
            weights = pickle.load(f)
        norm_stats = self._load_norm_stats(self.norm_stats_dir) if self.norm_stats_dir else None
        policy = Pi05Inference(
            checkpoint=weights,
            num_views=2,
            chunk_size=50,
            tokenizer_path=self.tokenizer_path,
            max_tokenize_len=200,
            max_prompt_text=self.prompt,
            discrete_state_input=False,
        )
        return policy, norm_stats

    def _load_norm_stats(self, norm_stats_dir: str) -> dict:
        """Read ``norm_stats.json`` from the directory; return None if absent."""
        norm_stats_path = os.path.join(norm_stats_dir, "norm_stats.json")
        if os.path.exists(norm_stats_path):
            with open(norm_stats_path, 'r') as f:
                return json.load(f)['norm_stats']
        return None

    def _parse_image(self, image) -> np.ndarray:
        """Coerce an image to uint8 HWC: floats are scaled by 255, CHW is transposed."""
        image = np.asarray(image)
        if np.issubdtype(image.dtype, np.floating):
            # assumes float images are in [0, 1] — TODO confirm against callers
            image = (255 * image).astype(np.uint8)
        if image.shape[0] == 3:
            image = einops.rearrange(image, "c h w -> h w c")
        return image

    def _pad_to_dim(self, x: np.ndarray, target_dim: int, axis: int = -1) -> np.ndarray:
        """Zero-pad `x` along `axis` up to `target_dim`; never truncates."""
        current_dim = x.shape[axis]
        if current_dim < target_dim:
            pad_width = [(0, 0)] * len(x.shape)
            pad_width[axis] = (0, target_dim - current_dim)
            return np.pad(x, pad_width)
        return x

    def _resize_with_pad(self, image: np.ndarray, height: int = 224, width: int = 224) -> np.ndarray:
        """Aspect-preserving resize, then center-pad with zeros to (height, width).

        NOTE(review): PIL bilinear resampling is unlikely to match the JAX
        backend's resize op bit-for-bit — a plausible contributor to small
        output MAE between backends; confirm against openpi's preprocessing.
        """
        pil_image = Image.fromarray(image)
        cur_width, cur_height = pil_image.size
        if cur_width == width and cur_height == height:
            return image
        ratio = max(cur_width / width, cur_height / height)
        resized_height = int(cur_height / ratio)
        resized_width = int(cur_width / ratio)
        resized_image = pil_image.resize((resized_width, resized_height), resample=Image.BILINEAR)
        zero_image = Image.new(resized_image.mode, (width, height), 0)
        pad_height = max(0, int((height - resized_height) / 2))
        pad_width = max(0, int((width - resized_width) / 2))
        zero_image.paste(resized_image, (pad_width, pad_height))
        return np.array(zero_image)

    def _normalize_image(self, image: np.ndarray) -> np.ndarray:
        """Map uint8 pixel values [0, 255] to float32 in [-1, 1]."""
        return image.astype(np.float32) / 255.0 * 2.0 - 1.0

    def _unnormalize_actions(self, actions: np.ndarray) -> np.ndarray:
        """Map model actions from [-1, 1] back to physical units via q01/q99.

        The quantiles are zero-padded to the model's 32-dim action space;
        the 1e-6 epsilon mirrors openpi's normalization to avoid a zero range.
        """
        q01 = self._pad_to_dim(self._actions_q01, 32)
        q99 = self._pad_to_dim(self._actions_q99, 32)
        return (actions + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01

    def infer(self, inputs: dict, noise: np.ndarray) -> dict:
        """Run one inference through the selected backend.

        Args:
            inputs: Observation dict. Triton mode reads "base_0_rgb",
                "left_wrist_0_rgb" and "prompt"; JAX mode is passed through
                to the openpi policy untouched.
            noise: Diffusion noise, shared between backends for comparability.

        Returns:
            Dict with "actions" truncated to the first `action_dim` columns.

        Raises:
            ValueError: If `model_type` is neither "triton" nor "jax".
        """
        if self.model_type == "triton":
            imgs = []
            for view in ["base_0_rgb", "left_wrist_0_rgb"]:
                img = self._parse_image(inputs[view])
                img = self._resize_with_pad(img, 224, 224)
                img = self._normalize_image(img)
                imgs.append(torch.from_numpy(img))
            observation_images = torch.stack(imgs, dim=0).to(torch.float32).cuda(non_blocking=True)
            diffusion_noise = torch.from_numpy(noise).to(torch.float32).cuda(non_blocking=True)
            # discrete_state_input=False: state tokens are unused, so pass None
            actions = self.policy.forward(observation_images, diffusion_noise, inputs.get("prompt"), None)
            actions = actions.cpu().float().numpy()
            actions = self._unnormalize_actions(actions)[:, :self.action_dim]
            return {"actions": actions}

        elif self.model_type == "jax":
            actions = self.policy.infer(inputs, noise=noise)
            actions["actions"] = actions["actions"][:, :self.action_dim]
            return actions

        # Previously this fell through and returned None silently; fail loudly.
        raise ValueError(
            f"Unknown model_type: {self.model_type!r} (expected 'triton' or 'jax')"
        )

def main():
    """Run one identical inference through the Triton and JAX backends and
    print the overall and per-dimension action MAE between them."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--triton_path', type=str, required=True)
    parser.add_argument('--jax_path', type=str, required=True)
    parser.add_argument('--norm_stats_dir', type=str, required=True)
    parser.add_argument('--config_name', type=str, default='pi05_libero')
    parser.add_argument('--prompt', type=str, default='pick up the fork')
    parser.add_argument('--tokenizer_path', type=str, required=True)
    parser.add_argument('--action_dim', type=int, default=7)
    args = parser.parse_args()

    def _read_image(path: str) -> np.ndarray:
        # cv2.imread returns None (no exception) for a missing or unreadable
        # file; fail fast here instead of crashing later in preprocessing.
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"Could not read test image: {path}")
        return img

    # NOTE(review): cv2.imread yields BGR channel order. Both backends get
    # the same arrays so the *comparison* is fair, but confirm the models
    # were trained with this channel order.
    example_image_global = _read_image("image1.png")
    example_image_left   = _read_image("image2.png")
    example_image_right  = _read_image("image3.png")

    # Reseeding before drawing `state` replays the start of the same RNG
    # stream used for `noise` — presumably intentional, purely for
    # reproducibility across runs; confirm if distinct draws were intended.
    np.random.seed(42)
    noise = np.random.randn(50, 32).astype(np.float32)
    np.random.seed(42)
    state = np.random.randn(8).astype(np.float32)

    inputs_triton = {
        "base_0_rgb":       example_image_global,
        "left_wrist_0_rgb": example_image_left,
        "right_wrist_0_rgb": example_image_right,
        "state":  state,
        "prompt": args.prompt,
    }
    inputs_jax = {
        "observation/image":      example_image_global,
        "observation/wrist_image": example_image_left,
        "observation/state":      state,
        "prompt": args.prompt,
    }

    common = dict(
        task='check_consistency',
        triton_path=args.triton_path,
        jax_path=args.jax_path,
        norm_stats_dir=args.norm_stats_dir,
        config_name=args.config_name,
        prompt=args.prompt,
        tokenizer_path=args.tokenizer_path,
        action_dim=args.action_dim,
    )

    # Triton
    pi_triton = Pi05ModelEvaluator(model_type='triton', **common)
    result_triton = pi_triton.infer(inputs_triton, noise)
    print("=== Triton done ===")

    # JAX
    pi_jax = Pi05ModelEvaluator(model_type='jax', **common)
    result_jax = pi_jax.infer(inputs_jax, noise)
    print("=== JAX done ===")

    t = result_triton['actions']
    j = result_jax['actions']
    mae = np.mean(np.abs(t - j))
    print(f"Triton actions range: [{t.min():.6f}, {t.max():.6f}]")
    print(f"JAX    actions range: [{j.min():.6f}, {j.max():.6f}]")
    print(f"Triton vs JAX MAE: {mae:.6f}")
    print("Per-dimension MAE:")
    for i in range(args.action_dim):
        dim_mae = np.mean(np.abs(t[:, i] - j[:, i]))
        print(f"  dim {i}: {dim_mae:.6f}")

# Run the comparison only when executed as a script, not on import.
if __name__ == "__main__":
    main()

Thanks a lot!

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions