diff --git a/demo.py b/demo.py index 5abc1da8..a3236c03 100644 --- a/demo.py +++ b/demo.py @@ -41,7 +41,7 @@ def viz(img, flo): def demo(args): model = torch.nn.DataParallel(RAFT(args)) - model.load_state_dict(torch.load(args.model)) + model.load_state_dict(torch.load(args.model, map_location=torch.device(DEVICE))) model = model.module model.to(DEVICE)