-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
37 lines (27 loc) · 1.03 KB
/
main.py
File metadata and controls
37 lines (27 loc) · 1.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
from torch.autograd import Variable
from torchvision.transforms import Compose
from torch.utils.data import DataLoader
from models.stereo import StereoRegression
from models.unary import ResnetUnary
from models.volume import CostVolumeConcat
from models.regression import ResnetRegression
from models.classification import SoftArgmin
from data.dataset_paths import get_all_kitti_paths
from data.dataset import KittiDataset
from data.transform import Padding, RandomCrop, ToTensor
def main():
    """Run one forward pass of the stereo model on a single KITTI sample.

    Builds a ``StereoRegression`` from its unary, cost-volume, regression and
    soft-argmin stages. It then loads the first sample of the KITTI dataset,
    transformed with ``RandomCrop((256, 512))`` and ``ToTensor()``, and feeds
    that sample's left/right images through the model as a batch of one.

    Returns:
        Whatever ``model(left, right)`` returns (the predicted disparity).
    """
    model = StereoRegression(
        ResnetUnary(),
        CostVolumeConcat(),
        ResnetRegression(),
        SoftArgmin(),
    )
    # Fall back to CPU so the script still runs on machines without CUDA;
    # on a CUDA machine this is equivalent to model.cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # This is a single inference pass: eval() puts layers such as
    # BatchNorm/Dropout (if the submodules use any) into inference mode.
    model.eval()

    transform = Compose((RandomCrop((256, 512)), ToTensor()))
    dataset = KittiDataset(get_all_kitti_paths(), transform)
    example = dataset[0]

    # unsqueeze (not the in-place unsqueeze_) adds the batch dimension without
    # mutating the tensors held in `example`. Tensors no longer need the
    # deprecated autograd.Variable wrapper.
    left = example['left'].unsqueeze(0).to(device)
    right = example['right'].unsqueeze(0).to(device)

    # no_grad() avoids building and holding an autograd graph for a pass
    # whose gradients are never used.
    with torch.no_grad():
        disp = model(left, right)
    # If the model returns a single tensor, a NumPy disparity map can be
    # obtained with disp.cpu().numpy().squeeze() — TODO confirm output type.
    return disp


if __name__ == "__main__":
    main()