My converted checkpoint does not produce the same actions as the original checkpoint. I also found issue #19, which describes a similar problem, but the script from issue #19 still doesn't work for me.
Triton actions range: [-0.992531, 1.050355]
JAX actions range: [-1.000153, 1.015160]
Triton vs JAX MAE: 0.110080
Per-dimension MAE:
Joint 1: 0.204920
Joint 2: 0.090841
Joint 3: 0.258178
Joint 4: 0.059499
Joint 5: 0.106889
Joint 6: 0.029172
Gripper Width: 0.021063
import numpy as np
import torch
import argparse
from PIL import Image
import einops
import json
import cv2
import os
from dexmal.pi0_infer import Pi0Inference
from openpi.training import config as _config
from openpi.policies import policy_config as _policy_config
import pickle
class Pi0ModelEvaluator:
    """Run one observation through either the Triton-converted or the original
    JAX pi0 checkpoint so the resulting action chunks can be compared.

    ``model_type`` selects the backend:
      - "triton": loads a pickled weight dict into ``Pi0Inference`` and drives
        its pre-allocated buffers manually, replicating the openpi input
        transforms (pad state to 32 dims, z-normalize, resize-with-pad images,
        scale pixels to [-1, 1]) in this class.
      - "jax": delegates everything (transforms included) to the openpi
        trained policy created from ``config_name``.
    """

    def __init__(self, task, model_type: str, triton_path: str, jax_path: str,
                 norm_stats_dir: str, config_name: str):
        self.triton_path = triton_path
        self.jax_path = jax_path
        self.task = task
        self.model_type = model_type
        self.norm_stats_dir = norm_stats_dir
        self.policy = None
        self.config_name = config_name
        self.norm_stats = None
        self.results = {
            'episode_results': []
        }
        # Only one backend is loaded per instance; an unknown model_type
        # leaves self.policy as None.
        if self.model_type == "triton":
            self.policy, self.norm_stats = self._load_model(model_type=self.model_type)
        elif self.model_type == "jax":
            self.policy = self._load_jax_model()

    def _load_jax_model(self):
        """Build the reference openpi policy from the named training config."""
        config = _config.get_config(self.config_name)
        policy = _policy_config.create_trained_policy(config, self.jax_path)
        return policy

    def _load_model(self, model_type="triton"):
        """Load the pickled Triton checkpoint and its norm stats (if any).

        NOTE(review): ``pickle.load`` executes arbitrary code from the file —
        only load checkpoints you trust.
        """
        with open(self.triton_path, 'rb') as f:
            weights = pickle.load(f)
        policy = Pi0Inference(checkpoint=weights, num_views=2, chunk_size=50)
        norm_stats = self._load_norm_stats(self.norm_stats_dir) if self.norm_stats_dir else None
        return policy, norm_stats

    def _load_norm_stats(self, norm_stats_dir: str) -> dict:
        """Read ``<dir>/norm_stats.json`` and return its 'norm_stats' entry.

        Returns None when the file does not exist.
        """
        norm_stats_path = os.path.join(norm_stats_dir, "norm_stats.json")
        if os.path.exists(norm_stats_path):
            with open(norm_stats_path, 'r') as f:
                return json.load(f)['norm_stats']
        return None

    def _parse_image(self, image) -> np.ndarray:
        """Coerce an image to uint8 HWC.

        Float inputs are assumed to be in [0, 1] and are scaled to [0, 255];
        a leading channel axis of size 3 (CHW) is moved last.
        """
        image = np.asarray(image)
        if np.issubdtype(image.dtype, np.floating):
            image = (255 * image).astype(np.uint8)
        if image.shape[0] == 3:
            image = einops.rearrange(image, "c h w -> h w c")
        return image

    def _pad_to_dim(self, x: np.ndarray, target_dim: int, axis: int = -1) -> np.ndarray:
        """Zero-pad ``x`` along ``axis`` up to ``target_dim`` (no-op if already there)."""
        current_dim = x.shape[axis]
        if current_dim < target_dim:
            pad_width = [(0, 0)] * len(x.shape)
            pad_width[axis] = (0, target_dim - current_dim)
            return np.pad(x, pad_width)
        return x

    def _resize_with_pad(self, image: np.ndarray, height: int = 224, width: int = 224) -> np.ndarray:
        """Resize keeping aspect ratio, then center on a zero canvas of
        (height, width) — mirrors openpi's resize-with-pad transform."""
        pil_image = Image.fromarray(image)
        cur_width, cur_height = pil_image.size
        if cur_width == width and cur_height == height:
            return image
        ratio = max(cur_width / width, cur_height / height)
        resized_height = int(cur_height / ratio)
        resized_width = int(cur_width / ratio)
        resized_image = pil_image.resize((resized_width, resized_height), resample=Image.BILINEAR)
        zero_image = Image.new(resized_image.mode, (width, height), 0)
        pad_height = max(0, int((height - resized_height) / 2))
        pad_width = max(0, int((width - resized_width) / 2))
        zero_image.paste(resized_image, (pad_width, pad_height))
        return np.array(zero_image)

    def _normalize_image(self, image: np.ndarray) -> np.ndarray:
        """Map uint8 pixels [0, 255] to float32 [-1, 1]."""
        image = image.astype(np.float32) / 255.0 * 2.0 - 1.0
        return image

    def _normalize_state(self, state: np.ndarray, norm_stats: dict) -> np.ndarray:
        """Z-normalize the (already padded) state with dataset statistics.

        Fix: the original returned None when stats were missing, which crashed
        later with an opaque NoneType error; the state now passes through
        unchanged instead.
        """
        if norm_stats and "state" in norm_stats:
            state_mean = np.array(norm_stats["state"]["mean"])
            state_mean = self._pad_to_dim(state_mean, 32)
            state_std = np.array(norm_stats["state"]["std"])
            state_std = self._pad_to_dim(state_std, 32)
            return (state - state_mean) / (state_std + 1e-6)
        return state

    def _unnormalize_state(self, actions: np.ndarray, norm_stats: dict) -> np.ndarray:
        """Invert action normalization with dataset statistics.

        Fix: the original returned None when stats were missing (crashing at
        the ``[:, :7]`` slice in ``infer``); actions now pass through unchanged.
        """
        if norm_stats and "actions" in norm_stats:
            actions_mean = np.array(norm_stats["actions"]["mean"])
            actions_mean = self._pad_to_dim(actions_mean, 32)
            actions_std = np.array(norm_stats["actions"]["std"])
            actions_std = self._pad_to_dim(actions_std, 32)
            return actions * (actions_std + 1e-6) + actions_mean
        return actions

    def _apply_input_transforms(self, data: dict, action_dim: int = 32, norm_stats: dict = None) -> dict:
        """Replicate the openpi input pipeline for the Triton backend:
        pad + normalize the state, and resize/normalize the three camera views.
        """
        state = self._pad_to_dim(data["state"], action_dim)
        state = self._normalize_state(state, norm_stats)
        base_image = self._parse_image(data["base_0_rgb"])
        left_wrist_image = self._parse_image(data["left_wrist_0_rgb"])
        right_wrist_image = self._parse_image(data["right_wrist_0_rgb"])
        base_image = self._resize_with_pad(base_image, 224, 224)
        base_image = self._normalize_image(base_image)
        left_wrist_image = self._resize_with_pad(left_wrist_image, 224, 224)
        left_wrist_image = self._normalize_image(left_wrist_image)
        right_wrist_image = self._resize_with_pad(right_wrist_image, 224, 224)
        right_wrist_image = self._normalize_image(right_wrist_image)
        image_dict = {
            "base_0_rgb": base_image,
            "left_wrist_0_rgb": left_wrist_image,
            "right_wrist_0_rgb": right_wrist_image,
        }
        # NOTE(review): left_wrist is masked False here, yet the Triton path
        # still feeds base + left_wrist into the model. If the original JAX
        # libero policy marks the wrist view as valid (mask True), this
        # discrepancy could explain part of the Triton-vs-JAX action gap —
        # verify against openpi's LiberoInputs transform.
        image_mask_dict = {
            "base_0_rgb": np.True_,
            "left_wrist_0_rgb": np.False_,
            "right_wrist_0_rgb": np.False_,
        }
        inputs = {
            "state": state,
            "image": image_dict,
            "image_mask": image_mask_dict,
        }
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]
        return inputs

    def infer(self, inputs: dict, noise: np.ndarray) -> dict:
        """Run one inference step and return {"actions": ...}.

        The same ``noise`` array is fed to both backends so the diffusion
        sampling is deterministic and the outputs are directly comparable.
        """
        if self.model_type == "triton":
            transformed_inputs = self._apply_input_transforms(inputs, action_dim=32, norm_stats=self.norm_stats)
            images = []
            for view in ["base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb"]:
                img = transformed_inputs["image"][view]
                images.append(torch.from_numpy(img))
            # Only the first two views are used (Pi0Inference was built with
            # num_views=2); the right-wrist image is dropped here.
            observation_images = torch.stack(images[0:2], dim=0).to(torch.float32)
            observation_state = torch.from_numpy(transformed_inputs["state"].astype(np.float32)).unsqueeze(0).to(torch.float32)
            normalized_observation_images = observation_images.cuda()
            normalized_observation_state = observation_state.cuda()
            diffusion_noise = torch.from_numpy(noise).to(torch.float32).cuda()
            self.policy.buffers['observation_images_normalized'].copy_(normalized_observation_images)
            self.policy.buffers['observation_state_normalized'].copy_(normalized_observation_state.squeeze(0))
            self.policy.buffers['diffusion_noise'].copy_(diffusion_noise)
            self.policy.record_run()
            # The runner writes denoised actions back into the noise buffer
            # in place.
            actions = self.policy.buffers['diffusion_noise']
            actions = actions.cpu().float().numpy()
            actions = self._unnormalize_state(actions, self.norm_stats)[:, :7]
            # NOTE(review): treats the first 6 dims as deltas and converts to
            # absolute by adding the current state — confirm the JAX policy's
            # output transform does the same, otherwise the comparison is
            # apples-to-oranges.
            actions[..., :6] = actions[..., :6] + inputs["state"][..., :6]
            return {
                "actions": actions
            }
        elif self.model_type == "jax":
            # openpi applies its own input/output transforms internally.
            actions = self.policy.infer(inputs, noise=noise)
            return actions
def main():
    """Feed one fixed observation (fixed noise) to both backends and print
    the per-dimension MAE between the Triton and JAX action chunks."""
    parser = argparse.ArgumentParser(description="VLA Model Inference")
    # Fix: the original default ended with a trailing space
    # ('...pi0_libero.pkl ') which makes open() fail with FileNotFoundError.
    parser.add_argument('--triton_path', type=str, default='~/workspace/ckpt/pi0_libero.pkl')
    parser.add_argument('--jax_path', type=str, default='~/.cache/openpi/openpi-assets/checkpoints/pi0_libero')
    parser.add_argument('--norm_stats_dir', type=str, default='~/.cache/openpi/openpi-assets/checkpoints/pi0_libero/assets/physical-intelligence/libero/')
    parser.add_argument('--config_name', type=str, default='pi0_libero')
    args = parser.parse_args()
    # Fix: open() and cv2.imread() do not expand '~' — expand (and strip
    # stray whitespace from) every path before use.
    triton_path = os.path.expanduser(args.triton_path.strip())
    jax_path = os.path.expanduser(args.jax_path.strip())
    norm_stats_dir = os.path.expanduser(args.norm_stats_dir.strip())
    global_image_path = os.path.expanduser("~/workspace/data/1.png")
    hand_image_path = os.path.expanduser("~/workspace/data/2.png")
    example_image_global = cv2.imread(global_image_path)
    example_image_hand = cv2.imread(hand_image_path)
    # Fix: cv2.imread silently returns None on failure; fail loudly instead
    # of crashing later inside the transforms.
    if example_image_global is None or example_image_hand is None:
        raise FileNotFoundError(
            f"Could not read test images: {global_image_path!r} / {hand_image_path!r}")
    # NOTE(review): cv2 loads BGR; both backends receive the same array, so
    # the comparison is consistent, but absolute actions may be wrong if the
    # model was trained on RGB — verify channel order against training data.
    # Fixed noise shared by both backends so sampling is deterministic.
    noise1 = torch.tensor([[-1.2578, -0.4023, -1.1250, 0.8789, -0.8633, 0.3457, 0.9414, -0.1226,
                            -1.1875, 0.0713, -1.7578, -1.0234, 1.0312, -0.0031, 0.0167, 0.0422,
                            -2.5938, 0.0087, 0.8398, -0.0830, 1.0156, -0.1553, -0.0806, 0.2988,
                            1.0312, -1.3281, -0.9062, 0.2754, 0.9336, -0.0457, 0.0757, -0.0317]],
                          device='cuda:0', dtype=torch.bfloat16)
    noises = noise1.expand(50, -1)
    noise = noises.float().cpu().numpy()
    noise = np.array(noise, dtype=np.float32)
    state = np.array([-0.2355, 0.1189, 1.1514, 3.1155, 0.0199, -0.1378, 0.0395, -0.0393], dtype=np.float32)
    inputs_triton = {
        "base_0_rgb": example_image_global,
        "left_wrist_0_rgb": example_image_hand,
        "right_wrist_0_rgb": example_image_hand,
        "state": state,
        "prompt": "Sort bowls and paper cups into their designated places"
    }
    inputs_jax = {
        "observation/image": example_image_global,
        "observation/wrist_image": example_image_hand,
        "observation/state": state,
        "prompt": "Sort bowls and paper cups into their designated places"
    }
    pi0_triton = Pi0ModelEvaluator(task='check_consistency', model_type="triton",
                                   triton_path=triton_path, jax_path=jax_path,
                                   norm_stats_dir=norm_stats_dir, config_name=args.config_name)
    result_triton = pi0_triton.infer(inputs_triton, noise)
    # Free the Triton backend before loading JAX to limit peak GPU memory.
    del pi0_triton
    print("finish triton")
    pi0_jax = Pi0ModelEvaluator(task='check_consistency', model_type="jax",
                                triton_path=triton_path, jax_path=jax_path,
                                norm_stats_dir=norm_stats_dir, config_name=args.config_name)
    result_jax = pi0_jax.infer(inputs_jax, noise)
    del pi0_jax
    print("finish jax")
    print(f"Triton actions range: [{result_triton['actions'].min():.6f}, {result_triton['actions'].max():.6f}]")
    print(f"JAX actions range: [{result_jax['actions'].min():.6f}, {result_jax['actions'].max():.6f}]")
    joint_names = ['Joint 1', 'Joint 2', 'Joint 3', 'Joint 4', 'Joint 5', 'Joint 6', 'Gripper Width']
    triton_jax_mae = np.mean(np.abs(result_triton['actions'] - result_jax['actions']))
    print(f"Triton vs JAX MAE: {triton_jax_mae:.6f}")
    print("Per-dimension MAE:")
    for i, joint_name in enumerate(joint_names):
        dim_mae = np.mean(np.abs(result_triton['actions'][:, i] - result_jax['actions'][:, i]))
        print(f"{joint_name}: {dim_mae:.6f}")
    print("Triton actions:", result_triton['actions'][:3])
    print("JAX actions:", result_jax['actions'][:3])


if __name__ == "__main__":
    main()
Is there any way to verify where the issue occurred? I would greatly appreciate any guidance.
Thanks
Dear authors,
thank you for providing the source code of realtime-vla.
My converted checkpoint does not produce the same actions as the original checkpoint. I also found issue #19, which describes a similar problem, but the script from issue #19 still doesn't work for me.
I used pi0_libero checkpoint:
Then convert to triton based:
Test script:
In addition, I added two lines to pi0.py to handle the noise data format:
The library versions(and hardware) used in the comparison are:
torch: 2.8.0
triton: 3.4.0,
CUDA: 12.6,
GPU: RTX5090
Is there any way to verify where the issue occurred? I would greatly appreciate any guidance.
Thanks