This discrepancy likely explains the performance drop observed when deploying the model via Triton-based inference engines (e.g., LeRobot) compared to the official JAX results.
=== PROMPT: pick up the black bowl in the top drawer of the wooden cabinet and place it on the plate ===
Triton GPU time: 42.81 ms (CUDA Event, avg 20 runs)
Triton Wall time: 42.76 ms (total end-to-end, avg 20 runs)
Triton actions range: [-0.469692, 1.002585]
JAX actions range: [-0.470159, 1.003163]
Triton vs JAX MAE: 0.002777
Per-dimension MAE:
dim 0: 0.002738
dim 1: 0.010315
dim 2: 0.002292
dim 3: 0.000328
dim 4: 0.001432
dim 5: 0.000470
dim 6: 0.001863
=== PROMPT: open the top drawer of the wooden cabinet ===
Triton GPU time: 43.32 ms (CUDA Event, avg 20 runs)
Triton Wall time: 43.30 ms (total end-to-end, avg 20 runs)
Triton actions range: [-0.534057, 0.998606]
JAX actions range: [-0.533780, 1.001180]
Triton vs JAX MAE: 0.002332
Per-dimension MAE:
dim 0: 0.004500
dim 1: 0.004662
dim 2: 0.001603
dim 3: 0.000440
dim 4: 0.001742
dim 5: 0.000989
dim 6: 0.002390
=== PROMPT: put the red mug on the coffee machine ===
Triton GPU time: 43.29 ms (CUDA Event, avg 20 runs)
Triton Wall time: 43.27 ms (total end-to-end, avg 20 runs)
Triton actions range: [-0.355880, 0.999427]
JAX actions range: [-0.352955, 1.000402]
Triton vs JAX MAE: 0.003484
Per-dimension MAE:
dim 0: 0.002887
dim 1: 0.014394
dim 2: 0.002141
dim 3: 0.000637
dim 4: 0.001431
dim 5: 0.000826
dim 6: 0.002071
=== PROMPT: pick up the fork and place it on the plate ===
Triton GPU time: 43.41 ms (CUDA Event, avg 20 runs)
Triton Wall time: 43.63 ms (total end-to-end, avg 20 runs)
Triton actions range: [-0.497158, 0.997394]
JAX actions range: [-0.498947, 0.997480]
Triton vs JAX MAE: 0.002958
Per-dimension MAE:
dim 0: 0.003967
dim 1: 0.010362
dim 2: 0.002003
dim 3: 0.000344
dim 4: 0.001341
dim 5: 0.000599
dim 6: 0.002093
=== PROMPT: close the top drawer of the wooden cabinet ===
Triton GPU time: 43.43 ms (CUDA Event, avg 20 runs)
Triton Wall time: 43.43 ms (total end-to-end, avg 20 runs)
Triton actions range: [-0.380434, 0.995849]
JAX actions range: [-0.382129, 0.998579]
Triton vs JAX MAE: 0.001578
Per-dimension MAE:
dim 0: 0.003547
dim 1: 0.002827
dim 2: 0.001437
dim 3: 0.000338
dim 4: 0.000293
dim 5: 0.000294
dim 6: 0.002311
This precision discrepancy leads to a 2% drop in success rate on the LIBERO benchmark.
import numpy as np
import torch
import argparse
from PIL import Image
import einops
import json
import cv2
import os
import pickle
from pi05_infer import Pi05Inference
from openpi.training import config as _config
from openpi.policies import policy_config as _policy_config
class Pi05ModelEvaluator:
    """Run pi0.5 action inference through either the Triton or the JAX backend
    so that the two implementations can be compared numerically.

    The Triton path performs its own image preprocessing and action
    un-normalization (using quantile stats loaded from ``norm_stats.json``),
    while the JAX path delegates preprocessing and normalization entirely to
    the openpi policy object.
    """

    def __init__(self, task, model_type: str, triton_path: str, jax_path: str,
                 norm_stats_dir: str, config_name: str, prompt: str = "do something",
                 tokenizer_path: str = None, action_dim: int = 7):
        """Load the requested backend.

        Args:
            task: Opaque task identifier (stored, not interpreted here).
            model_type: Either ``"triton"`` or ``"jax"``.
            triton_path: Path to the pickled Triton checkpoint weights.
            jax_path: Path to the JAX checkpoint directory.
            norm_stats_dir: Directory containing ``norm_stats.json`` (Triton only).
            config_name: openpi config name used to build the JAX policy.
            prompt: Default language instruction.
            tokenizer_path: Tokenizer path for the Triton inference engine.
            action_dim: Number of leading action dimensions to return.

        Raises:
            ValueError: If ``model_type`` is not ``"triton"`` or ``"jax"``.
        """
        self.triton_path = triton_path
        self.jax_path = jax_path
        self.task = task
        self.model_type = model_type
        self.norm_stats_dir = norm_stats_dir
        self.config_name = config_name
        self.prompt = prompt
        self.tokenizer_path = tokenizer_path
        self.action_dim = action_dim
        self.policy = None
        self.norm_stats = None
        if self.model_type == "triton":
            self.policy, self.norm_stats = self._load_triton_model()
        elif self.model_type == "jax":
            # The JAX policy handles normalization internally, so no
            # norm_stats are kept on the evaluator for this backend.
            self.policy = self._load_jax_model()
        else:
            # Previously an unknown model_type silently left self.policy as
            # None and only failed later inside infer(); fail fast instead.
            raise ValueError(
                f"Unknown model_type: {self.model_type!r}; expected 'triton' or 'jax'"
            )
        # 256 bins over [-1, 1); kept for discrete-state tokenization but
        # unused in the visible code path (discrete_state_input=False).
        self._digitize_bins = np.linspace(-1, 1, 256 + 1)[:-1]
        # q01/q99 quantiles for state and actions (None for the JAX backend).
        self._state_q01 = np.array(self.norm_stats["state"]["q01"]) if self.norm_stats else None
        self._state_q99 = np.array(self.norm_stats["state"]["q99"]) if self.norm_stats else None
        self._actions_q01 = np.array(self.norm_stats["actions"]["q01"]) if self.norm_stats else None
        self._actions_q99 = np.array(self.norm_stats["actions"]["q99"]) if self.norm_stats else None

    def _load_jax_model(self):
        """Build the reference openpi policy from the named config and checkpoint."""
        config = _config.get_config(self.config_name)
        policy = _policy_config.create_trained_policy(config, self.jax_path)
        return policy

    def _load_triton_model(self):
        """Load the pickled Triton checkpoint and its normalization stats.

        Returns:
            Tuple of (Pi05Inference policy, norm_stats dict or None).
        """
        # NOTE(review): pickle.load executes arbitrary code on load — only use
        # checkpoints from trusted sources.
        with open(self.triton_path, 'rb') as f:
            weights = pickle.load(f)
        norm_stats = self._load_norm_stats(self.norm_stats_dir) if self.norm_stats_dir else None
        policy = Pi05Inference(
            checkpoint=weights,
            num_views=2,
            chunk_size=50,
            tokenizer_path=self.tokenizer_path,
            max_tokenize_len=200,
            max_prompt_text=self.prompt,
            discrete_state_input=False,
        )
        return policy, norm_stats

    def _load_norm_stats(self, norm_stats_dir: str) -> dict:
        """Read ``norm_stats.json`` from the directory, or return None if absent."""
        norm_stats_path = os.path.join(norm_stats_dir, "norm_stats.json")
        if os.path.exists(norm_stats_path):
            with open(norm_stats_path, 'r') as f:
                return json.load(f)['norm_stats']
        return None

    def _parse_image(self, image) -> np.ndarray:
        """Coerce an image to uint8 HWC layout.

        Float inputs are assumed to lie in [0, 1] and are scaled to [0, 255];
        a leading channel dim of 3 is moved to the end (CHW -> HWC).
        """
        image = np.asarray(image)
        if np.issubdtype(image.dtype, np.floating):
            image = (255 * image).astype(np.uint8)
        if image.shape[0] == 3:
            image = einops.rearrange(image, "c h w -> h w c")
        return image

    def _pad_to_dim(self, x: np.ndarray, target_dim: int, axis: int = -1) -> np.ndarray:
        """Zero-pad ``x`` along ``axis`` up to ``target_dim`` (no-op if already wide enough)."""
        current_dim = x.shape[axis]
        if current_dim < target_dim:
            pad_width = [(0, 0)] * len(x.shape)
            pad_width[axis] = (0, target_dim - current_dim)
            return np.pad(x, pad_width)
        return x

    def _resize_with_pad(self, image: np.ndarray, height: int = 224, width: int = 224) -> np.ndarray:
        """Resize ``image`` preserving aspect ratio, then center-pad with zeros
        to exactly (height, width)."""
        pil_image = Image.fromarray(image)
        cur_width, cur_height = pil_image.size
        if cur_width == width and cur_height == height:
            return image
        # Scale so the larger relative dimension exactly fits the target.
        ratio = max(cur_width / width, cur_height / height)
        resized_height = int(cur_height / ratio)
        resized_width = int(cur_width / ratio)
        resized_image = pil_image.resize((resized_width, resized_height), resample=Image.BILINEAR)
        zero_image = Image.new(resized_image.mode, (width, height), 0)
        pad_height = max(0, int((height - resized_height) / 2))
        pad_width = max(0, int((width - resized_width) / 2))
        zero_image.paste(resized_image, (pad_width, pad_height))
        return np.array(zero_image)

    def _normalize_image(self, image: np.ndarray) -> np.ndarray:
        """Map a uint8 image from [0, 255] to float32 [-1, 1]."""
        return image.astype(np.float32) / 255.0 * 2.0 - 1.0

    def _unnormalize_actions(self, actions: np.ndarray) -> np.ndarray:
        """Map actions from the normalized [-1, 1] range back to physical units
        using the q01/q99 quantiles (zero-padded to 32 dims to match the model output).

        The 1e-6 matches the epsilon used during normalization, so padded
        (q01 == q99 == 0) dimensions stay near zero instead of dividing by zero.
        """
        q01 = self._pad_to_dim(self._actions_q01, 32)
        q99 = self._pad_to_dim(self._actions_q99, 32)
        return (actions + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01

    def infer(self, inputs: dict, noise: np.ndarray) -> dict:
        """Run one inference step and return ``{"actions": ...}`` truncated to
        ``self.action_dim`` action dimensions.

        Args:
            inputs: Observation dict; keys differ per backend (camera views,
                state, and ``prompt``).
            noise: Fixed diffusion noise so both backends are comparable.
        """
        if self.model_type == "triton":
            imgs = []
            for view in ["base_0_rgb", "left_wrist_0_rgb"]:
                img = self._parse_image(inputs[view])
                img = self._resize_with_pad(img, 224, 224)
                img = self._normalize_image(img)
                imgs.append(torch.from_numpy(img))
            observation_images = torch.stack(imgs, dim=0).to(torch.float32).cuda(non_blocking=True)
            diffusion_noise = torch.from_numpy(noise).to(torch.float32).cuda(non_blocking=True)
            # discrete_state_input=False: state_tokens are unused, pass None.
            actions = self.policy.forward(observation_images, diffusion_noise, inputs.get("prompt"), None)
            actions = actions.cpu().float().numpy()
            actions = self._unnormalize_actions(actions)[:, :self.action_dim]
            return {"actions": actions}
        elif self.model_type == "jax":
            actions = self.policy.infer(inputs, noise=noise)
            actions["actions"] = actions["actions"][:, :self.action_dim]
            return actions
def main():
    """CLI entry point: run the same observation with identical diffusion
    noise through both the Triton and the JAX backends, then report action
    ranges and the overall / per-dimension MAE between them."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--triton_path', type=str, required=True)
    parser.add_argument('--jax_path', type=str, required=True)
    parser.add_argument('--norm_stats_dir', type=str, required=True)
    parser.add_argument('--config_name', type=str, default='pi05_libero')
    parser.add_argument('--prompt', type=str, default='pick up the fork')
    parser.add_argument('--tokenizer_path', type=str, required=True)
    parser.add_argument('--action_dim', type=int, default=7)
    args = parser.parse_args()

    def _read_image(path: str) -> np.ndarray:
        # cv2.imread silently returns None for a missing or unreadable file,
        # which would only surface later as an opaque error inside infer().
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return img

    # Both backends receive the same BGR arrays, so the comparison is
    # consistent even though cv2 loads BGR rather than RGB.
    example_image_global = _read_image("image1.png")
    example_image_left = _read_image("image2.png")
    example_image_right = _read_image("image3.png")

    # Fixed seeds so both backends see identical noise and state.
    # NOTE(review): reseeding with 42 makes `state` identical to the first 8
    # values of the stream that produced `noise` — confirm this overlap is
    # intended rather than an artifact of copy-paste.
    np.random.seed(42)
    noise = np.random.randn(50, 32).astype(np.float32)
    np.random.seed(42)
    state = np.random.randn(8).astype(np.float32)

    inputs_triton = {
        "base_0_rgb": example_image_global,
        "left_wrist_0_rgb": example_image_left,
        "right_wrist_0_rgb": example_image_right,
        "state": state,
        "prompt": args.prompt,
    }
    inputs_jax = {
        "observation/image": example_image_global,
        "observation/wrist_image": example_image_left,
        "observation/state": state,
        "prompt": args.prompt,
    }
    common = dict(
        task='check_consistency',
        triton_path=args.triton_path,
        jax_path=args.jax_path,
        norm_stats_dir=args.norm_stats_dir,
        config_name=args.config_name,
        prompt=args.prompt,
        tokenizer_path=args.tokenizer_path,
        action_dim=args.action_dim,
    )

    # Triton
    pi_triton = Pi05ModelEvaluator(model_type='triton', **common)
    result_triton = pi_triton.infer(inputs_triton, noise)
    print("=== Triton done ===")
    # JAX
    pi_jax = Pi05ModelEvaluator(model_type='jax', **common)
    result_jax = pi_jax.infer(inputs_jax, noise)
    print("=== JAX done ===")

    t = result_triton['actions']
    j = result_jax['actions']
    mae = np.mean(np.abs(t - j))
    print(f"Triton actions range: [{t.min():.6f}, {t.max():.6f}]")
    print(f"JAX actions range: [{j.min():.6f}, {j.max():.6f}]")
    print(f"Triton vs JAX MAE: {mae:.6f}")
    print("Per-dimension MAE:")
    for i in range(args.action_dim):
        dim_mae = np.mean(np.abs(t[:, i] - j[:, i]))
        print(f" dim {i}: {dim_mae:.6f}")
# Script entry point: run the Triton-vs-JAX comparison only when executed directly.
if __name__ == "__main__":
    main()
Hello:
Thanks for your great work!
I have performed a precision alignment test between the JAX (Reference) and Triton (Inference) backends for the pi0.5 model on the LIBERO dataset. While the inference latency is optimized in Triton (~43ms), there is a noticeable Mean Absolute Error (MAE) in the action space, particularly concentrated in Dimension 1.
This discrepancy likely explains the performance drop observed when deploying the model via Triton-based inference engines (e.g., LeRobot) compared to the official JAX results.
This precision discrepancy leads to a 2% drop in success rate on the LIBERO benchmark.
Is this result correct, and how can I improve the numerical precision of the Triton backend?
Here is my code:
Thanks a lot!