Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
69e48b3
Adding torch.hub support
MarcSzafraniec Jul 6, 2020
c586c65
Adding torch.hub support
MarcSzafraniec Jul 6, 2020
23965dd
Adding torch.hub support
MarcSzafraniec Jul 6, 2020
9eb6c48
Adding torch.hub support
MarcSzafraniec Jul 6, 2020
5e6d17a
Adding torch.hub support
MarcSzafraniec Jul 6, 2020
ac959b8
Adding torch.hub support
MarcSzafraniec Jul 6, 2020
023d033
Adding torch.hub support
MarcSzafraniec Jul 6, 2020
185efa2
Adding torch.hub support
MarcSzafraniec Jul 6, 2020
fccac01
Adding torch.hub support
MarcSzafraniec Jul 6, 2020
429d50f
Adding torch.hub support
MarcSzafraniec Jul 6, 2020
369e229
Adding torch.hub support
MarcSzafraniec Jul 6, 2020
ca0cb80
Adding torch.hub support for RAFT
MarcSzafraniec Jul 6, 2020
8d10b3f
Adding torch.hub support for RAFT
MarcSzafraniec Jul 6, 2020
25c2748
Adding torch.hub support
MarcSzafraniec Jul 6, 2020
9cac81d
Adding torch.hub support
MarcSzafraniec Jul 7, 2020
2c029dd
Adding torch.hub support
MarcSzafraniec Jul 7, 2020
878a35b
Adding torch.hub support
MarcSzafraniec Jul 7, 2020
bbe6540
Adding torch.hub support
MarcSzafraniec Jul 7, 2020
fec87f2
Adding torch.hub support
MarcSzafraniec Jul 7, 2020
db1a3ef
Adding torch.hub support
MarcSzafraniec Jul 7, 2020
051e855
Adding torch.hub support
MarcSzafraniec Jul 21, 2020
6a14975
Adding torch.hub support
MarcSzafraniec Jul 21, 2020
f60e583
Adding torch.hub support
MarcSzafraniec Aug 4, 2020
3bdce9d
Adding torch.hub support
MarcSzafraniec Aug 4, 2020
8627032
Adding torch.hub support
MarcSzafraniec Aug 4, 2020
ccc124a
Adding torch.hub support
MarcSzafraniec Aug 4, 2020
8ace306
Adding torch.hub support
MarcSzafraniec Aug 4, 2020
2511e5b
Adding torch.hub support
MarcSzafraniec Aug 4, 2020
ef70326
Adding torch.hub support
MarcSzafraniec Aug 4, 2020
65d4f69
Adding torch.hub support
MarcSzafraniec Aug 4, 2020
e946a6d
Adding torch.hub support
MarcSzafraniec Aug 4, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added core/.DS_Store
Binary file not shown.
8 changes: 4 additions & 4 deletions core/modules/corr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import torch.nn.functional as F
from utils.utils import bilinear_sampler, coords_grid
from ..utils.utils import bilinear_sampler, coords_grid

class CorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
Expand All @@ -13,7 +13,7 @@ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):

batch, h1, w1, dim, h2, w2 = corr.shape
corr = corr.view(batch*h1*w1, dim, h2, w2)

self.corr_pyramid.append(corr)
for i in range(self.num_levels):
corr = F.avg_pool2d(corr, 2, stride=2)
Expand Down Expand Up @@ -47,7 +47,7 @@ def corr(fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht*wd)
fmap2 = fmap2.view(batch, dim, ht*wd)

corr = torch.matmul(fmap1.transpose(1,2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())
return corr / torch.sqrt(torch.tensor(dim).float())
20 changes: 9 additions & 11 deletions core/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import torch.nn as nn
import torch.nn.functional as F

from modules.update import BasicUpdateBlock, SmallUpdateBlock
from modules.extractor import BasicEncoder, SmallEncoder
from modules.corr import CorrBlock
from utils.utils import bilinear_sampler, coords_grid, upflow8
from .modules.update import BasicUpdateBlock, SmallUpdateBlock
from .modules.extractor import BasicEncoder, SmallEncoder
from .modules.corr import CorrBlock
from .utils.utils import bilinear_sampler, coords_grid, upflow8


class RAFT(nn.Module):
Expand All @@ -19,7 +19,7 @@ def __init__(self, args):
self.context_dim = cdim = 64
args.corr_levels = 4
args.corr_radius = 3

else:
self.hidden_dim = hdim = 128
self.context_dim = cdim = 128
Expand All @@ -31,12 +31,12 @@ def __init__(self, args):

# feature network, context network, and update block
if args.small:
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

else:
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

Expand Down Expand Up @@ -86,14 +86,12 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True):

# F(t+1) = F(t) + \Delta(t)
coords1 = coords1 + delta_flow

if upsample:
flow_up = upflow8(coords1 - coords0)
flow_predictions.append(flow_up)

else:
flow_predictions.append(coords1 - coords0)

return flow_predictions


6 changes: 6 additions & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# torch.hub entry point. `dependencies` lists the packages torch.hub verifies
# are installed before loading models from this repository.
dependencies = ['torch']

# Re-export the RAFT factory so `torch.hub.load(<repo>, 'RAFT')` can find it.
# The plain import works when torch.hub executes hubconf.py with the repo root
# on sys.path; the relative import is the fallback for when this repository is
# imported as a package.
try:
    from hubconf_models import RAFT
except ModuleNotFoundError:
    from .hubconf_models import RAFT
103 changes: 103 additions & 0 deletions hubconf_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import argparse
import io
import os
import time
import urllib.request
import zipfile
import torch
from torch.nn import functional as F

try:
from core.raft import RAFT as RAFT_module
except ModuleNotFoundError:
from .core.raft import RAFT as RAFT_module

# Zip archive containing all pretrained RAFT checkpoints.
models_url = "https://www.dropbox.com/s/a2acvmczgzm6f9n/models.zip?dl=1" # dl=1 makes Dropbox serve the raw file instead of an HTML preview page


# Public torch.hub entry points exported by this module.
__all__ = ["RAFT"]


# Environment variables and fallback used to locate torch's cache directory;
# resolution order is implemented in _get_torch_home below.
ENV_TORCH_HOME = "TORCH_HOME"
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
DEFAULT_CACHE_DIR = "~/.cache"


def _get_torch_home():
    """Return the torch cache directory.

    Resolution order mirrors torch.hub: $TORCH_HOME if set, otherwise
    $XDG_CACHE_HOME/torch, otherwise ~/.cache/torch (with ~ expanded).
    """
    cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    default_home = os.path.join(cache_root, "torch")
    return os.path.expanduser(os.getenv(ENV_TORCH_HOME, default_home))


def _pad8(img):
"""pad image such that dimensions are divisible by 8"""
ht, wd = img.shape[2:]
pad_ht = (((ht // 8) + 1) * 8 - ht) % 8
pad_wd = (((wd // 8) + 1) * 8 - wd) % 8
pad_ht1 = [pad_ht // 2, pad_ht - pad_ht // 2]
pad_wd1 = [pad_wd // 2, pad_wd - pad_wd // 2]

img = F.pad(img, pad_wd1 + pad_ht1, mode="replicate")
return img


def RAFT(pretrained=False, model_name="chairs+things", device=None, **kwargs):
    """
    RAFT model (https://arxiv.org/abs/2003.12039)

    Args:
        pretrained (bool): if True, download (and cache under torch's hub
            directory) the checkpoint named by ``model_name`` and load it.
        model_name (str): One of 'chairs+things', 'sintel', 'kitti' and
            'small'; note that for 'small', the architecture is smaller.
        device: target device; defaults to the current CUDA device when
            available, else "cpu".
        **kwargs: forwarded to the underlying RAFT module as attributes
            (e.g. ``dropout``).

    Returns:
        A ``DataParallel``-wrapped RAFT module in eval mode.

    Raises:
        ValueError: if ``model_name`` is not one of the known checkpoints.
    """
    model_list = ["chairs+things", "sintel", "kitti", "small"]
    if model_name not in model_list:
        raise ValueError("Model should be one of " + str(model_list))

    model_args = argparse.Namespace(**kwargs)
    model_args.small = "small" in model_name
    # The underlying module reads args.dropout unconditionally; default to no
    # dropout so RAFT() works without the caller passing dropout=...
    if not hasattr(model_args, "dropout"):
        model_args.dropout = 0

    model = RAFT_module(model_args)
    if device is None:
        device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
    if device != "cpu":
        model = torch.nn.DataParallel(model, device_ids=[device])
    else:
        model = torch.nn.DataParallel(model)
        model.device_ids = None

    if pretrained:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, "checkpoints", "models_RAFT")
        model_path = os.path.join(model_dir, "models", model_name + ".pth")
        # Check for the checkpoint file itself, not just the directory: a
        # partially extracted archive would otherwise skip the download and
        # fail later in torch.load. (This also replaces the previous fixed
        # 10-second sleep that hoped a concurrent download would finish.)
        if not os.path.exists(model_path):
            os.makedirs(model_dir, exist_ok=True)
            response = urllib.request.urlopen(models_url, timeout=10)
            with zipfile.ZipFile(io.BytesIO(response.read())) as archive:
                archive.extractall(model_dir)

        map_location = torch.device('cpu') if device == "cpu" else None
        model.load_state_dict(torch.load(model_path, map_location=map_location))

    model = model.to(device)
    model.eval()
    return model


def apply_model(model, images_from, images_to, iters=12, upsample=True):
    """
    Applies optical flow model to the pairs of images
    Args:
        model: a RAFT model as returned by :func:`RAFT`.
        images_from: torch.Tensor of size [B, H, W, C] containing RGB data for
            images that serve as optical flow source images
        images_to: torch.Tensor of size [B, H, W, C] containing RGB data for
            images that serve as optical flow destination images
        iters: number of flow refinement iterations to run.
        upsample: whether the model should upsample the predicted flow.
    Return:
        optical_flow: torch.Tensor of size [B, H, W, 2]
    """
    # Both inputs must have dimensions divisible by 8 before entering RAFT.
    padded_from = _pad8(images_from)
    padded_to = _pad8(images_to)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        return model(image1=padded_from, image2=padded_to, iters=iters, upsample=upsample)