Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions conda_environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ dependencies:
- -e git+https://github.com/ARISE-Initiative/robomimic@main#egg=robomimic
- diffusers
- zarr
- h5py
- robomimic
- diffusers
- zarr
- einops
- tqdm
- pybullet
Expand Down
1 change: 0 additions & 1 deletion eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def eval_main(cfg):
# run policy in environment
success_count = 0
for i in range(cfg.num_episodes):
runner.reset()
rewards, info = runner.run(agent, cfg.max_steps)
assert "success" in info, "info['success'] not returned in info from runner"
print(f"info: {info}")
Expand Down
12 changes: 0 additions & 12 deletions imitation/config/policy/robomimic_eef.yaml

This file was deleted.

4 changes: 2 additions & 2 deletions imitation/dataset/robomimic_eef_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def __init__(self,
self.indices = []
self.data_at_indices = []
# if indices file exists, load it
index_file = dataset_path.replace(".hdf5", f"_indices_{obs_horizon}_{action_horizon}_{pred_horizon}.npy")
data_at_indices_file = dataset_path.replace(".hdf5", f"_data_at_indices_{obs_horizon}_{action_horizon}_{pred_horizon}.npy")
index_file = dataset_path.replace(".hdf5", f"_eef_indices_{obs_horizon}_{action_horizon}_{pred_horizon}.npy")
data_at_indices_file = dataset_path.replace(".hdf5", f"_eef_data_at_indices_{obs_horizon}_{action_horizon}_{pred_horizon}.npy")
if os.path.exists(index_file):
self.indices = np.load(index_file)
self.data_at_indices = np.load(data_at_indices_file, allow_pickle=True)
Expand Down
5 changes: 3 additions & 2 deletions imitation/dataset/robomimic_lowdim_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ def create_sample_indices(self):
|------------ pred_horizon -------------|
'''
idx_global = 0
n_latency_steps = 0
for key in tqdm(self.dataset_keys):
episode_length = len(self.dataset_root[f"data/{key}/obs/{self.obs_keys[0]}"])
for idx in range(episode_length - self.pred_horizon):
for idx in range(episode_length - self.pred_horizon + n_latency_steps):
if idx - self.obs_horizon < 0:
continue
self.indices.append(idx_global + idx)
Expand All @@ -100,7 +101,7 @@ def create_sample_indices(self):
data_obs_keys.append(obs)
data_action_keys = []
for action_key in self.action_keys:
action = self.dataset_root[f"data/{key}/obs/{action_key}"][idx:idx+self.pred_horizon, :]
action = self.dataset_root[f"data/{key}/obs/{action_key}"][idx + n_latency_steps:idx + n_latency_steps + self.pred_horizon, :]
if "quat" in action_key:
action = self.rotation_transformer.forward(action)
data_action_keys.append(action)
Expand Down
3 changes: 2 additions & 1 deletion imitation/env/robomimic_lowdim_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,12 @@ def _robosuite_obs_to_robomimic_obs(self, obs):
j = i*39
# 7 - sin of joint angles
robot_joint_pos = obs[j:j + 7]
# 7 - sin of joint angles
# 7 - sin of joint angles
# robot_joint_sin = obs[j + 7:j + 14]
# 7 - cos of joint angles
# robot_joint_cos = obs[j + 14:j + 21]
# 7 - joint velocities
# robot_joint_vel = obs[j + 21:j + 28]
eef_pose = obs[j + 28:j + 31]
eef_quat = obs[j + 31:j + 35]
eef_6d = self.rotation_transformer.forward(eef_quat)
Expand Down
5 changes: 3 additions & 2 deletions imitation/env_runner/robomimic_lowdim_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ def reset(self) -> None:
self.obs_deque = collections.deque(
[self.obs] * self.obs_horizon, maxlen=self.obs_horizon)

def run(self, agent: BaseAgent, n_steps: int) -> Dict:
def run(self, agent: BaseAgent, n_steps: int = 100) -> Dict:
log.info(f"Running agent {agent.__class__.__name__} for {n_steps} steps")
self.reset()
if self.output_video:
self.start_video()
done = False
Expand All @@ -87,12 +88,12 @@ def run(self, agent: BaseAgent, n_steps: int) -> Dict:
if self.output_video:
self.end_video()
return rewards, info

obs, reward, done, info = self.env.step(action)
self.obs_deque.append(obs)

if self.render:
self.env.render()
# time.sleep(1/self.fps) # TODO properly fix the rendering speed or not

if self.output_video:
# We need to directly grab full observations so we can get image data
Expand Down
29 changes: 11 additions & 18 deletions imitation/policy/diffusion_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from tqdm.auto import tqdm
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
import wandb

Expand Down Expand Up @@ -159,13 +160,7 @@ def get_action(self, obs_seq):
# only take action_horizon number of actions
action = action_pred[:self.action_horizon,:]
# (action_horizon, action_dim)
return action # TODO limit this in runner

def validate(self, dataset=None, model_path=None):
'''
Calculate validation loss for noise prediction model in the given dataset
'''
return None
return action

def train(self,
dataset=None,
Expand All @@ -188,10 +183,9 @@ def train(self,
# accelerates training and improves stability
# holds a copy of the model weights

# TODO use EMA
# ema = EMAModel(
# model=noise_pred_net,
# power=0.75)
ema = EMAModel(
parameters=self.noise_pred_net.parameters(),
power=0.75)

# Standard ADAM optimizer
# Note that EMA parameters are not optimized
Expand Down Expand Up @@ -262,9 +256,8 @@ def train(self,
# this is different from standard pytorch behavior
lr_scheduler.step()

# TODO use EMA
# update Exponential Moving Average of the model weights
# ema.step(noise_pred_net)
ema.step(self.noise_pred_net.parameters())


# logging
Expand All @@ -273,10 +266,10 @@ def train(self,
tepoch.set_postfix(loss=loss_cpu)
tglobal.set_postfix(loss=np.mean(epoch_loss))
wandb.log({'epoch_loss': np.mean(epoch_loss)})
# Weights of the EMA model are used for inference
ema_noise_pred_net = self.noise_pred_net
ema.copy_to(ema_noise_pred_net.parameters())
# save model checkpoint
torch.save(self.noise_pred_net.state_dict(), model_path)
torch.save(ema_noise_pred_net.state_dict(), model_path)

# Weights of the EMA model
# is used for inference
# ema_noise_pred_net = ema.averaged_model
self.ema_noise_pred_net = self.noise_pred_net