-
-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtest_.py
More file actions
19 lines (14 loc) · 624 Bytes
/
test_.py
File metadata and controls
19 lines (14 loc) · 624 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
from model import MattingRefine
# Smoke test: build a MattingRefine model, load pretrained weights from a
# TorchScript archive, and run a single forward pass on random inputs.
device = torch.device('cpu')
precision = torch.float32

# NOTE(review): upstream BackgroundMattingV2 documents MattingRefine backbones
# as 'resnet50', 'resnet101' and 'mobilenetv2'; confirm the local `model`
# module accepts 'torchscript_resnet50_fp32' as a backbone name.
model = MattingRefine(backbone='torchscript_resnet50_fp32',
                      backbone_scale=0.25,
                      refine_mode='sampling',
                      refine_sample_pixels=80_000)

# load_state_dict() expects a mapping of parameter names to tensors, not a
# module, so pull the weights out of the TorchScript archive via state_dict().
# map_location lets an archive saved on a GPU be loaded onto `device`.
model.load_state_dict(
    torch.jit.load('model/TorchScript/torchscript_resnet50_fp32.pth',
                   map_location=device).state_dict())
model = model.eval().to(device=device, dtype=precision)

# Random inputs of shape (1, 3, 1080, 1920), created directly on the target
# device/dtype. Presumably NCHW, i.e. one 3-channel 1080x1920 frame - confirm
# against MattingRefine's expected input layout.
src = torch.rand(1, 3, 1080, 1920, device=device, dtype=precision)
bgr = torch.rand(1, 3, 1080, 1920, device=device, dtype=precision)

with torch.no_grad():
    # Keep only the first two outputs of the model, bound as pha and fgr.
    pha, fgr = model(src, bgr)[:2]