Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,23 @@ conda install pip
pip install -r requirements.txt
```

## Running on Apple Silicon
To run this code on Apple silicon, set up the environment as follows:

```
conda create --name pips python=3.9
source activate pips
conda install pytorch torchvision torchaudio -c pytorch
conda install pip
pip install -r requirements.txt
```

At the time of writing, PyTorch's MPS backend does not support every operator this code requires, so you will also need to enable CPU fallback for the unsupported operators:

`export PYTORCH_ENABLE_MPS_FALLBACK=1`



## Demo

To get our reference model, download the weights from [Hugging Face ![](https://img.shields.io/badge/🤗%20Hugging%20Face-Model-blue)](https://huggingface.co/aharley/pips).
Expand Down
27 changes: 17 additions & 10 deletions chain_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
random.seed(125)
np.random.seed(125)

def run_model(model, rgbs, N, sw):
rgbs = rgbs.cuda().float() # B, S, C, H, W
def run_model(model, rgbs, N, sw, device):
rgbs = rgbs.to(device).float() # B, S, C, H, W

B, S, C, H, W = rgbs.shape
rgbs_ = rgbs.reshape(B*S, C, H, W)
Expand All @@ -31,17 +31,17 @@ def run_model(model, rgbs, N, sw):
# try to pick a point on the dog, so we get an interesting trajectory
# x = torch.randint(-10, 10, size=(1, N), device=torch.device('cuda')) + 468
# y = torch.randint(-10, 10, size=(1, N), device=torch.device('cuda')) + 118
x = torch.ones((1, N), device=torch.device('cuda')) * 450.0
y = torch.ones((1, N), device=torch.device('cuda')) * 100.0
x = torch.ones((1, N), device=device) * 450.0
y = torch.ones((1, N), device=device) * 100.0
xy0 = torch.stack([x, y], dim=-1) # B, N, 2
_, S, C, H, W = rgbs.shape

trajs_e = torch.zeros((B, S, N, 2), dtype=torch.float32, device='cuda')
trajs_e = torch.zeros((B, S, N, 2), dtype=torch.float32, device=device)
for n in range(N):
# print('working on keypoint %d/%d' % (n+1, N))
cur_frame = 0
done = False
traj_e = torch.zeros((B, S, 2), dtype=torch.float32, device='cuda')
traj_e = torch.zeros((B, S, 2), dtype=torch.float32, device=device)
traj_e[:,0] = xy0[:,n] # B, 1, 2 # set first position
feat_init = None
while not done:
Expand All @@ -51,7 +51,7 @@ def run_model(model, rgbs, N, sw):
S_local = rgb_seq.shape[1]
rgb_seq = torch.cat([rgb_seq, rgb_seq[:,-1].unsqueeze(1).repeat(1,8-S_local,1,1,1)], dim=1)

outs = model(traj_e[:,cur_frame].reshape(1, -1, 2), rgb_seq, iters=6, feat_init=feat_init, return_feat=True)
outs = model(traj_e[:,cur_frame].reshape(1, -1, 2), rgb_seq, device, iters=6, feat_init=feat_init, return_feat=True)
preds = outs[0]
vis = outs[2] # B, S, 1
feat_init = outs[3]
Expand Down Expand Up @@ -143,10 +143,17 @@ def main():

global_step = 0

model = Pips(stride=4).cuda()
if torch.cuda.is_available():
device = torch.device('cuda')
elif torch.backends.mps.is_available():
device = torch.device('mps')
else:
device = torch.device('cpu')

model = Pips(stride=4).to(device)
parameters = list(model.parameters())
if init_dir:
_ = saverloader.load(init_dir, model)
_ = saverloader.load(init_dir, model, device=device)
global_step = 0
model.eval()

Expand Down Expand Up @@ -179,7 +186,7 @@ def main():
iter_start_time = time.time()

with torch.no_grad():
trajs_e = run_model(model, rgbs, N, sw_t)
trajs_e = run_model(model, rgbs, N, sw_t, device)

iter_time = time.time()-iter_start_time
print('%s; step %06d/%d; rtime %.2f; itime %.2f' % (
Expand Down
23 changes: 16 additions & 7 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
random.seed(125)
np.random.seed(125)

def run_model(model, rgbs, N, sw):
rgbs = rgbs.cuda().float() # B, S, C, H, W
def run_model(model, rgbs, N, sw, device):
rgbs = rgbs.to(device).float() # B, S, C, H, W

B, S, C, H, W = rgbs.shape
rgbs_ = rgbs.reshape(B*S, C, H, W)
Expand All @@ -30,14 +30,14 @@ def run_model(model, rgbs, N, sw):

# pick N points to track; we'll use a uniform grid
N_ = np.sqrt(N).round().astype(np.int32)
grid_y, grid_x = utils.basic.meshgrid2d(B, N_, N_, stack=False, norm=False, device='cuda')
grid_y, grid_x = utils.basic.meshgrid2d(B, N_, N_, stack=False, norm=False, device=device)
grid_y = 8 + grid_y.reshape(B, -1)/float(N_-1) * (H-16)
grid_x = 8 + grid_x.reshape(B, -1)/float(N_-1) * (W-16)
xy = torch.stack([grid_x, grid_y], dim=-1) # B, N_*N_, 2
_, S, C, H, W = rgbs.shape

print_stats('rgbs', rgbs)
preds, preds_anim, vis_e, stats = model(xy, rgbs, iters=6)
preds, preds_anim, vis_e, stats = model(xy, rgbs, device, iters=6, )
trajs_e = preds[-1]
print_stats('trajs_e', trajs_e)

Expand Down Expand Up @@ -111,10 +111,19 @@ def main():

global_step = 0

model = Pips(stride=4).cuda()
# Pick a device
if torch.cuda.is_available():
device = torch.device('cuda')
elif torch.backends.mps.is_available():
device = torch.device('mps')
else:
device = torch.device('cpu')


model = Pips(stride=4).to(device)
parameters = list(model.parameters())
if init_dir:
_ = saverloader.load(init_dir, model)
_ = saverloader.load(init_dir, model, device=device)
global_step = 0
model.eval()

Expand Down Expand Up @@ -147,7 +156,7 @@ def main():
iter_start_time = time.time()

with torch.no_grad():
trajs_e = run_model(model, rgbs, N, sw_t)
trajs_e = run_model(model, rgbs, N, sw_t, device)

iter_time = time.time()-iter_start_time
print('%s; step %06d/%d; rtime %.2f; itime %.2f' % (
Expand Down
4 changes: 2 additions & 2 deletions nets/pips.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,8 @@ def __init__(self, S=8, stride=8):
nn.Linear(self.latent_dim, 1),
)

def forward(self, xys, rgbs, coords_init=None, feat_init=None, iters=3, trajs_g=None, vis_g=None, valids=None, sw=None, return_feat=False, is_train=False):
total_loss = torch.tensor(0.0).cuda()
def forward(self, xys, rgbs, device, coords_init=None, feat_init=None, iters=3, trajs_g=None, vis_g=None, valids=None, sw=None, return_feat=False, is_train=False):
total_loss = torch.tensor(0.0).to(device)

B, N, D = xys.shape
assert(D==2)
Expand Down
4 changes: 2 additions & 2 deletions saverloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def save(ckpt_dir, optimizer, model, global_step, scheduler=None, model_ema=None
torch.save(ckpt, model_path)
print("saved a checkpoint: %s" % (model_path))

def load(ckpt_dir, model, optimizer=None, scheduler=None, model_ema=None, step=0, model_name='model', ignore_load=None):
def load(ckpt_dir, model, device, optimizer=None, scheduler=None, model_ema=None, step=0, model_name='model', ignore_load=None):
print('reading ckpt from %s' % ckpt_dir)
if not os.path.exists(ckpt_dir):
print('...there is no full checkpoint here!')
Expand Down Expand Up @@ -55,7 +55,7 @@ def load(ckpt_dir, model, optimizer=None, scheduler=None, model_ema=None, step=0
# 3. load the new state dict
model.load_state_dict(model_dict, strict=False)
else:
checkpoint = torch.load(path)
checkpoint = torch.load(path, map_location='mps')
model.load_state_dict(checkpoint['model_state_dict'], strict=False)

if optimizer is not None:
Expand Down