From 68b695ef3a72b8675aedece05363bdeed79fb294 Mon Sep 17 00:00:00 2001 From: Dave Date: Sat, 12 Aug 2023 13:36:27 +1000 Subject: [PATCH 1/5] FTR: Add support for MPS on Apple silicon --- chain_demo.py | 27 +++++++++++++++++---------- demo.py | 19 ++++++++++++++----- nets/pips.py | 4 ++-- saverloader.py | 4 ++-- 4 files changed, 35 insertions(+), 19 deletions(-) diff --git a/chain_demo.py b/chain_demo.py index 4c78727..ddf182d 100644 --- a/chain_demo.py +++ b/chain_demo.py @@ -18,8 +18,8 @@ random.seed(125) np.random.seed(125) -def run_model(model, rgbs, N, sw): - rgbs = rgbs.cuda().float() # B, S, C, H, W +def run_model(model, rgbs, N, sw, dev): + rgbs = rgbs.to(device).float() # B, S, C, H, W B, S, C, H, W = rgbs.shape rgbs_ = rgbs.reshape(B*S, C, H, W) @@ -31,17 +31,17 @@ def run_model(model, rgbs, N, sw): # try to pick a point on the dog, so we get an interesting trajectory # x = torch.randint(-10, 10, size=(1, N), device=torch.device('cuda')) + 468 # y = torch.randint(-10, 10, size=(1, N), device=torch.device('cuda')) + 118 - x = torch.ones((1, N), device=torch.device('cuda')) * 450.0 - y = torch.ones((1, N), device=torch.device('cuda')) * 100.0 + x = torch.ones((1, N), device=dev) * 450.0 + y = torch.ones((1, N), device=dev) * 100.0 xy0 = torch.stack([x, y], dim=-1) # B, N, 2 _, S, C, H, W = rgbs.shape - trajs_e = torch.zeros((B, S, N, 2), dtype=torch.float32, device='cuda') + trajs_e = torch.zeros((B, S, N, 2), dtype=torch.float32, device=dev) for n in range(N): # print('working on keypoint %d/%d' % (n+1, N)) cur_frame = 0 done = False - traj_e = torch.zeros((B, S, 2), dtype=torch.float32, device='cuda') + traj_e = torch.zeros((B, S, 2), dtype=torch.float32, device=dev) traj_e[:,0] = xy0[:,n] # B, 1, 2 # set first position feat_init = None while not done: @@ -51,7 +51,7 @@ def run_model(model, rgbs, N, sw): S_local = rgb_seq.shape[1] rgb_seq = torch.cat([rgb_seq, rgb_seq[:,-1].unsqueeze(1).repeat(1,8-S_local,1,1,1)], dim=1) - outs = model(traj_e[:,cur_frame].reshape(1, -1, 2), rgb_seq, iters=6, feat_init=feat_init, return_feat=True) + outs = model(traj_e[:,cur_frame].reshape(1, -1, 2), rgb_seq, device=dev, iters=6, feat_init=feat_init, return_feat=True) preds = outs[0] vis = outs[2] # B, S, 1 feat_init = outs[3] @@ -143,10 +143,17 @@ def main(): global_step = 0 - model = Pips(stride=4).cuda() + if torch.cuda.is_available(): + device = torch.device('cuda') + else if torch.backends.mps.is_available(): + device = torch.device('mps') + else: + device = torch.device('cpu') + + model = Pips(stride=4).to(device) parameters = list(model.parameters()) if init_dir: - _ = saverloader.load(init_dir, model) + _ = saverloader.load(init_dir, model, device=device) global_step = 0 model.eval() @@ -179,7 +186,7 @@ def main(): iter_start_time = time.time() with torch.no_grad(): - trajs_e = run_model(model, rgbs, N, sw_t) + trajs_e = run_model(model, rgbs, N, sw_t, device) iter_time = time.time()-iter_start_time print('%s; step %06d/%d; rtime %.2f; itime %.2f' % ( diff --git a/demo.py b/demo.py index f3b6183..8a872c8 100644 --- a/demo.py +++ b/demo.py @@ -18,8 +18,8 @@ random.seed(125) np.random.seed(125) -def run_model(model, rgbs, N, sw): - rgbs = rgbs.cuda().float() # B, S, C, H, W +def run_model(model, rgbs, N, sw, device): + rgbs = rgbs.to(device).float() # B, S, C, H, W B, S, C, H, W = rgbs.shape rgbs_ = rgbs.reshape(B*S, C, H, W) @@ -30,7 +30,7 @@ def run_model(model, rgbs, N, sw): # pick N points to track; we'll use a uniform grid N_ = np.sqrt(N).round().astype(np.int32) - grid_y, grid_x = utils.basic.meshgrid2d(B, N_, N_, stack=False, norm=False, device='cuda') + grid_y, grid_x = utils.basic.meshgrid2d(B, N_, N_, stack=False, norm=False, device=device) grid_y = 8 + grid_y.reshape(B, -1)/float(N_-1) * (H-16) grid_x = 8 + grid_x.reshape(B, -1)/float(N_-1) * (W-16) xy = torch.stack([grid_x, grid_y], dim=-1) # B, N_*N_, 2 @@ -111,10 +111,19 @@ def main(): global_step = 0 - model = Pips(stride=4).cuda() + # Pick a device + if torch.cuda.is_available(): + device = torch.device('cuda') + else if torch.backends.mps.is_available(): + device = torch.device('mps') + else: + device = torch.device('cpu') + + + model = Pips(stride=4).to(device) parameters = list(model.parameters()) if init_dir: - _ = saverloader.load(init_dir, model) + _ = saverloader.load(init_dir, model, device=device) global_step = 0 model.eval() diff --git a/nets/pips.py b/nets/pips.py index bc9ff4e..903fc54 100644 --- a/nets/pips.py +++ b/nets/pips.py @@ -425,8 +425,8 @@ def __init__(self, S=8, stride=8): nn.Linear(self.latent_dim, 1), ) - def forward(self, xys, rgbs, coords_init=None, feat_init=None, iters=3, trajs_g=None, vis_g=None, valids=None, sw=None, return_feat=False, is_train=False): - total_loss = torch.tensor(0.0).cuda() + def forward(self, xys, rgbs, device, coords_init=None, feat_init=None, iters=3, trajs_g=None, vis_g=None, valids=None, sw=None, return_feat=False, is_train=False): + total_loss = torch.tensor(0.0).to(dev) B, N, D = xys.shape assert(D==2) diff --git a/saverloader.py b/saverloader.py index 8f66ab8..def61db 100644 --- a/saverloader.py +++ b/saverloader.py @@ -22,7 +22,7 @@ def save(ckpt_dir, optimizer, model, global_step, scheduler=None, model_ema=None torch.save(ckpt, model_path) print("saved a checkpoint: %s" % (model_path)) -def load(ckpt_dir, model, optimizer=None, scheduler=None, model_ema=None, step=0, model_name='model', ignore_load=None): +def load(ckpt_dir, model, device, optimizer=None, scheduler=None, model_ema=None, step=0, model_name='model', ignore_load=None): print('reading ckpt from %s' % ckpt_dir) if not os.path.exists(ckpt_dir): print('...there is no full checkpoint here!') @@ -55,7 +55,7 @@ def load(ckpt_dir, model, optimizer=None, scheduler=None, model_ema=None, step=0 # 3. load the new state dict model.load_state_dict(model_dict, strict=False) else: - checkpoint = torch.load(path) + checkpoint = torch.load(path, map_location=device) model.load_state_dict(checkpoint['model_state_dict'], strict=False) if optimizer is not None: From 32598cc613d91cc38b0fb74cd263f98769a50ba8 Mon Sep 17 00:00:00 2001 From: Dave Date: Sat, 12 Aug 2023 14:03:12 +1000 Subject: [PATCH 2/5] DOC: Updated the README --- README.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/README.md b/README.md index b9ed13a..7221164 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,22 @@ This is the official code release for our ECCV 2022 paper, "Particle Video Revis +### Update 08/12/23: +Added support for Apples silicon. +Set up as below but instead of: + +`conda install pytorch=1.12.0 torchvision=0.13.0 cudatoolkit=11.3 -c pytorch` + +use + +`conda install pytorch torchvision torchaudio -c pytorch` + +At the time of this fix, MPS does not support all operatord required for this code to run so you will also need : + +`export PYTORCH_ENABLE_MPS_FALLBACK=1` + +I also found I had to `conda install python=3.9` as the 3.11 I had didn't work. + ### Update 07/27/23: Very soon, this repo will also host the PIPs++ model from our ICCV 2023 paper, "PointOdyssey: A Large-Scale Synthetic Dataset for Long-Term Point Tracking". **[[Paper](https://arxiv.org/abs/2307.15055)] [[Project Page](https://pointodyssey.com/)]** From d404db3525dc5f7c2068a92edc3a4a302f668428 Mon Sep 17 00:00:00 2001 From: Dave Date: Sat, 12 Aug 2023 14:04:00 +1000 Subject: [PATCH 3/5] FIX: Updated scripts to accept device. --- chain_demo.py | 14 +++++++------- demo.py | 6 +++--- nets/pips.py | 2 +- saverloader.py | 2 +- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/chain_demo.py b/chain_demo.py index ddf182d..7beb152 100644 --- a/chain_demo.py +++ b/chain_demo.py @@ -18,7 +18,7 @@ random.seed(125) np.random.seed(125) -def run_model(model, rgbs, N, sw, dev): +def run_model(model, rgbs, N, sw, device): rgbs = rgbs.to(device).float() # B, S, C, H, W B, S, C, H, W = rgbs.shape @@ -31,17 +31,17 @@ def run_model(model, rgbs, N, sw, dev): # try to pick a point on the dog, so we get an interesting trajectory # x = torch.randint(-10, 10, size=(1, N), device=torch.device('cuda')) + 468 # y = torch.randint(-10, 10, size=(1, N), device=torch.device('cuda')) + 118 - x = torch.ones((1, N), device=dev) * 450.0 - y = torch.ones((1, N), device=dev) * 100.0 + x = torch.ones((1, N), device=device) * 450.0 + y = torch.ones((1, N), device=device) * 100.0 xy0 = torch.stack([x, y], dim=-1) # B, N, 2 _, S, C, H, W = rgbs.shape - trajs_e = torch.zeros((B, S, N, 2), dtype=torch.float32, device=dev) + trajs_e = torch.zeros((B, S, N, 2), dtype=torch.float32, device=device) for n in range(N): # print('working on keypoint %d/%d' % (n+1, N)) cur_frame = 0 done = False - traj_e = torch.zeros((B, S, 2), dtype=torch.float32, device=dev) + traj_e = torch.zeros((B, S, 2), dtype=torch.float32, device=device) traj_e[:,0] = xy0[:,n] # B, 1, 2 # set first position feat_init = None while not done: @@ -51,7 +51,7 @@ def run_model(model, rgbs, N, sw, dev): S_local = rgb_seq.shape[1] rgb_seq = torch.cat([rgb_seq, rgb_seq[:,-1].unsqueeze(1).repeat(1,8-S_local,1,1,1)], dim=1) - outs = model(traj_e[:,cur_frame].reshape(1, -1, 2), rgb_seq, device=dev, iters=6, feat_init=feat_init, return_feat=True) + outs = model(traj_e[:,cur_frame].reshape(1, -1, 2), rgb_seq, device, iters=6, feat_init=feat_init, return_feat=True) preds = outs[0] vis = outs[2] # B, S, 1 feat_init = outs[3] @@ -145,7 +145,7 @@ def main(): if torch.cuda.is_available(): device = torch.device('cuda') - else if torch.backends.mps.is_available(): + elif torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') diff --git a/demo.py b/demo.py index 8a872c8..8f994c3 100644 --- a/demo.py +++ b/demo.py @@ -37,7 +37,7 @@ def run_model(model, rgbs, N, sw, device): _, S, C, H, W = rgbs.shape print_stats('rgbs', rgbs) - preds, preds_anim, vis_e, stats = model(xy, rgbs, iters=6) + preds, preds_anim, vis_e, stats = model(xy, rgbs, device, iters=6, ) trajs_e = preds[-1] print_stats('trajs_e', trajs_e) @@ -114,7 +114,7 @@ def main(): # Pick a device if torch.cuda.is_available(): device = torch.device('cuda') - else if torch.backends.mps.is_available(): + elif torch.backends.mps.is_available(): device = torch.device('mps') else: device = torch.device('cpu') @@ -156,7 +156,7 @@ def main(): iter_start_time = time.time() with torch.no_grad(): - trajs_e = run_model(model, rgbs, N, sw_t) + trajs_e = run_model(model, rgbs, N, sw_t, device) iter_time = time.time()-iter_start_time print('%s; step %06d/%d; rtime %.2f; itime %.2f' % ( diff --git a/nets/pips.py b/nets/pips.py index 903fc54..a8936fd 100644 --- a/nets/pips.py +++ b/nets/pips.py @@ -426,7 +426,7 @@ def __init__(self, S=8, stride=8): ) def forward(self, xys, rgbs, device, coords_init=None, feat_init=None, iters=3, trajs_g=None, vis_g=None, valids=None, sw=None, return_feat=False, is_train=False): - total_loss = torch.tensor(0.0).to(dev) + total_loss = torch.tensor(0.0).to(device) B, N, D = xys.shape assert(D==2) diff --git a/saverloader.py b/saverloader.py index def61db..a8a55a4 100644 --- a/saverloader.py +++ b/saverloader.py @@ -55,7 +55,7 @@ def load(ckpt_dir, model, device, optimizer=None, scheduler=None, model_ema=None # 3. load the new state dict model.load_state_dict(model_dict, strict=False) else: - checkpoint = torch.load(path, map_location=device) + checkpoint = torch.load(path, map_location='mps') model.load_state_dict(checkpoint['model_state_dict'], strict=False) if optimizer is not None: From f0137bfbf9b129a3df059c7ffa9d801ddc608940 Mon Sep 17 00:00:00 2001 From: Dave Date: Sat, 12 Aug 2023 14:08:30 +1000 Subject: [PATCH 4/5] DOC: Retructure readme --- README.md | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 7221164..72bc37d 100644 --- a/README.md +++ b/README.md @@ -4,22 +4,6 @@ This is the official code release for our ECCV 2022 paper, "Particle Video Revis -### Update 08/12/23: -Added support for Apples silicon. -Set up as below but instead of: - -`conda install pytorch=1.12.0 torchvision=0.13.0 cudatoolkit=11.3 -c pytorch` - -use - -`conda install pytorch torchvision torchaudio -c pytorch` - -At the time of this fix, MPS does not support all operatord required for this code to run so you will also need : - -`export PYTORCH_ENABLE_MPS_FALLBACK=1` - -I also found I had to `conda install python=3.9` as the 3.11 I had didn't work. - ### Update 07/27/23: Very soon, this repo will also host the PIPs++ model from our ICCV 2023 paper, "PointOdyssey: A Large-Scale Synthetic Dataset for Long-Term Point Tracking". **[[Paper](https://arxiv.org/abs/2307.15055)] [[Project Page](https://pointodyssey.com/)]** @@ -67,6 +51,24 @@ conda install pip pip install -r requirements.txt ``` +## Running on Apple Silicon +If you want to run this code on Apple's silicon, use the following: + +``` +conda create --name pips +source activate pips +conda install python=3.9 +conda install pytorch torchvision torchaudio -c pytorch +conda install pip +pip install -r requirements.txt +``` + +At the time of this fix, MPS does not support all operators required for this code to run so you will also need: + +`export PYTORCH_ENABLE_MPS_FALLBACK=1` + + + ## Demo To download our reference model, download the weights from [Hugging Face. ![](https://img.shields.io/badge/🤗%20Hugging%20Face-Model-blue)](https://huggingface.co/aharley/pips) From af9c0cdcce776fcfba574e6b1ed41765b7513e61 Mon Sep 17 00:00:00 2001 From: Dave Date: Sat, 12 Aug 2023 14:17:11 +1000 Subject: [PATCH 5/5] DOCS: Simplified setup of conda env. --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 72bc37d..ada7a02 100644 --- a/README.md +++ b/README.md @@ -55,9 +55,8 @@ pip install -r requirements.txt If you want to run this code on Apple's silicon, use the following: ``` -conda create --name pips +conda create --name pips python=3.9 source activate pips -conda install python=3.9 conda install pytorch torchvision torchaudio -c pytorch conda install pip pip install -r requirements.txt