diff --git a/README.md b/README.md
index b9ed13a..ada7a02 100644
--- a/README.md
+++ b/README.md
@@ -51,6 +51,23 @@ conda install pip
 pip install -r requirements.txt
 ```
 
+## Running on Apple Silicon
+If you want to run this code on Apple silicon, use the following:
+
+```
+conda create --name pips python=3.9
+source activate pips
+conda install pytorch torchvision torchaudio -c pytorch
+conda install pip
+pip install -r requirements.txt
+```
+
+At the time of this fix, MPS does not support all of the operators required for this code to run, so you will also need:
+
+`export PYTORCH_ENABLE_MPS_FALLBACK=1`
+
+
+
 ## Demo
 
 To download our reference model, download the weights from [Hugging Face. ![](https://img.shields.io/badge/🤗%20Hugging%20Face-Model-blue)](https://huggingface.co/aharley/pips)
diff --git a/chain_demo.py b/chain_demo.py
index 4c78727..7beb152 100644
--- a/chain_demo.py
+++ b/chain_demo.py
@@ -18,8 +18,8 @@
 random.seed(125)
 np.random.seed(125)
 
-def run_model(model, rgbs, N, sw):
-    rgbs = rgbs.cuda().float() # B, S, C, H, W
+def run_model(model, rgbs, N, sw, device):
+    rgbs = rgbs.to(device).float() # B, S, C, H, W
 
     B, S, C, H, W = rgbs.shape
     rgbs_ = rgbs.reshape(B*S, C, H, W)
@@ -31,17 +31,17 @@ def run_model(model, rgbs, N, sw):
     # try to pick a point on the dog, so we get an interesting trajectory
     # x = torch.randint(-10, 10, size=(1, N), device=torch.device('cuda')) + 468
     # y = torch.randint(-10, 10, size=(1, N), device=torch.device('cuda')) + 118
-    x = torch.ones((1, N), device=torch.device('cuda')) * 450.0
-    y = torch.ones((1, N), device=torch.device('cuda')) * 100.0
+    x = torch.ones((1, N), device=device) * 450.0
+    y = torch.ones((1, N), device=device) * 100.0
     xy0 = torch.stack([x, y], dim=-1) # B, N, 2
     _, S, C, H, W = rgbs.shape
 
-    trajs_e = torch.zeros((B, S, N, 2), dtype=torch.float32, device='cuda')
+    trajs_e = torch.zeros((B, S, N, 2), dtype=torch.float32, device=device)
     for n in range(N):
         # print('working on keypoint %d/%d' % (n+1, N))
         cur_frame = 0
         done = False
-        traj_e = torch.zeros((B, S, 2), dtype=torch.float32, device='cuda')
+        traj_e = torch.zeros((B, S, 2), dtype=torch.float32, device=device)
         traj_e[:,0] = xy0[:,n] # B, 1, 2  # set first position
         feat_init = None
         while not done:
@@ -51,7 +51,7 @@ def run_model(model, rgbs, N, sw):
             S_local = rgb_seq.shape[1]
             rgb_seq = torch.cat([rgb_seq, rgb_seq[:,-1].unsqueeze(1).repeat(1,8-S_local,1,1,1)], dim=1)
 
-            outs = model(traj_e[:,cur_frame].reshape(1, -1, 2), rgb_seq, iters=6, feat_init=feat_init, return_feat=True)
+            outs = model(traj_e[:,cur_frame].reshape(1, -1, 2), rgb_seq, device, iters=6, feat_init=feat_init, return_feat=True)
             preds = outs[0]
             vis = outs[2] # B, S, 1
             feat_init = outs[3]
@@ -143,10 +143,17 @@
 
     global_step = 0
 
-    model = Pips(stride=4).cuda()
+    if torch.cuda.is_available():
+        device = torch.device('cuda')
+    elif torch.backends.mps.is_available():
+        device = torch.device('mps')
+    else:
+        device = torch.device('cpu')
+
+    model = Pips(stride=4).to(device)
     parameters = list(model.parameters())
     if init_dir:
-        _ = saverloader.load(init_dir, model)
+        _ = saverloader.load(init_dir, model, device=device)
         global_step = 0
     model.eval()
 
@@ -179,7 +186,7 @@ def main():
             iter_start_time = time.time()
 
             with torch.no_grad():
-                trajs_e = run_model(model, rgbs, N, sw_t)
+                trajs_e = run_model(model, rgbs, N, sw_t, device)
             iter_time = time.time()-iter_start_time
 
             print('%s; step %06d/%d; rtime %.2f; itime %.2f' % (
diff --git a/demo.py b/demo.py
index f3b6183..8f994c3 100644
--- a/demo.py
+++ b/demo.py
@@ -18,8 +18,8 @@
 random.seed(125)
 np.random.seed(125)
 
-def run_model(model, rgbs, N, sw):
-    rgbs = rgbs.cuda().float() # B, S, C, H, W
+def run_model(model, rgbs, N, sw, device):
+    rgbs = rgbs.to(device).float() # B, S, C, H, W
 
     B, S, C, H, W = rgbs.shape
     rgbs_ = rgbs.reshape(B*S, C, H, W)
@@ -30,14 +30,14 @@ def run_model(model, rgbs, N, sw):
 
     # pick N points to track; we'll use a uniform grid
     N_ = np.sqrt(N).round().astype(np.int32)
-    grid_y, grid_x = utils.basic.meshgrid2d(B, N_, N_, stack=False, norm=False, device='cuda')
+    grid_y, grid_x = utils.basic.meshgrid2d(B, N_, N_, stack=False, norm=False, device=device)
     grid_y = 8 + grid_y.reshape(B, -1)/float(N_-1) * (H-16)
     grid_x = 8 + grid_x.reshape(B, -1)/float(N_-1) * (W-16)
     xy = torch.stack([grid_x, grid_y], dim=-1) # B, N_*N_, 2
     _, S, C, H, W = rgbs.shape
 
     print_stats('rgbs', rgbs)
-    preds, preds_anim, vis_e, stats = model(xy, rgbs, iters=6)
+    preds, preds_anim, vis_e, stats = model(xy, rgbs, device, iters=6)
     trajs_e = preds[-1]
 
     print_stats('trajs_e', trajs_e)
@@ -111,10 +111,19 @@
 
     global_step = 0
 
-    model = Pips(stride=4).cuda()
+    # Pick a device
+    if torch.cuda.is_available():
+        device = torch.device('cuda')
+    elif torch.backends.mps.is_available():
+        device = torch.device('mps')
+    else:
+        device = torch.device('cpu')
+
+
+    model = Pips(stride=4).to(device)
     parameters = list(model.parameters())
     if init_dir:
-        _ = saverloader.load(init_dir, model)
+        _ = saverloader.load(init_dir, model, device=device)
         global_step = 0
     model.eval()
 
@@ -147,7 +156,7 @@ def main():
             iter_start_time = time.time()
 
             with torch.no_grad():
-                trajs_e = run_model(model, rgbs, N, sw_t)
+                trajs_e = run_model(model, rgbs, N, sw_t, device)
             iter_time = time.time()-iter_start_time
 
             print('%s; step %06d/%d; rtime %.2f; itime %.2f' % (
diff --git a/nets/pips.py b/nets/pips.py
index bc9ff4e..a8936fd 100644
--- a/nets/pips.py
+++ b/nets/pips.py
@@ -425,8 +425,8 @@ def __init__(self, S=8, stride=8):
             nn.Linear(self.latent_dim, 1),
         )
 
-    def forward(self, xys, rgbs, coords_init=None, feat_init=None, iters=3, trajs_g=None, vis_g=None, valids=None, sw=None, return_feat=False, is_train=False):
-        total_loss = torch.tensor(0.0).cuda()
+    def forward(self, xys, rgbs, device, coords_init=None, feat_init=None, iters=3, trajs_g=None, vis_g=None, valids=None, sw=None, return_feat=False, is_train=False):
+        total_loss = torch.tensor(0.0).to(device)
 
         B, N, D = xys.shape
         assert(D==2)
diff --git a/saverloader.py b/saverloader.py
index 8f66ab8..a8a55a4 100644
--- a/saverloader.py
+++ b/saverloader.py
@@ -22,7 +22,7 @@ def save(ckpt_dir, optimizer, model, global_step, scheduler=None, model_ema=None
     torch.save(ckpt, model_path)
     print("saved a checkpoint: %s" % (model_path))
 
-def load(ckpt_dir, model, optimizer=None, scheduler=None, model_ema=None, step=0, model_name='model', ignore_load=None):
+def load(ckpt_dir, model, device, optimizer=None, scheduler=None, model_ema=None, step=0, model_name='model', ignore_load=None):
     print('reading ckpt from %s' % ckpt_dir)
     if not os.path.exists(ckpt_dir):
         print('...there is no full checkpoint here!')
@@ -55,7 +55,7 @@ def load(ckpt_dir, model, optimizer=None, scheduler=None, model_ema=None, step=0
             # 3. load the new state dict
             model.load_state_dict(model_dict, strict=False)
         else:
-            checkpoint = torch.load(path)
+            checkpoint = torch.load(path, map_location=device)
             model.load_state_dict(checkpoint['model_state_dict'], strict=False)
 
         if optimizer is not None:
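
As a quick standalone illustration of how the device selection and the MPS fallback flag in this patch fit together, here is a minimal sketch. It assumes PyTorch 1.12+ (for `torch.backends.mps`); the `pick_device` helper and the test tensor are illustrative only and are not part of the diff above.

```python
import os

# Assumption: PYTORCH_ENABLE_MPS_FALLBACK is read when PyTorch initializes the
# MPS backend, so set it before importing torch (or export it in the shell,
# as the README change above recommends).
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch


def pick_device():
    # Mirror the selection order added to demo.py / chain_demo.py:
    # CUDA if available, then Apple's MPS backend, then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


if __name__ == "__main__":
    device = pick_device()
    print("using device:", device)
    # Allocate a small tensor on the chosen device to confirm it is usable.
    x = torch.ones((2, 3), device=device)
    print((x * 2.0).sum().item())
```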