diff --git a/core/.DS_Store b/core/.DS_Store
new file mode 100644
index 00000000..4a5bdfbe
Binary files /dev/null and b/core/.DS_Store differ
diff --git a/core/modules/corr.py b/core/modules/corr.py
index 63db794a..002dfed9 100644
--- a/core/modules/corr.py
+++ b/core/modules/corr.py
@@ -1,6 +1,6 @@
 import torch
 import torch.nn.functional as F
-from utils.utils import bilinear_sampler, coords_grid
+from ..utils.utils import bilinear_sampler, coords_grid
 
 class CorrBlock:
     def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
@@ -13,7 +13,7 @@ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
 
         batch, h1, w1, dim, h2, w2 = corr.shape
         corr = corr.view(batch*h1*w1, dim, h2, w2)
-        
+
         self.corr_pyramid.append(corr)
         for i in range(self.num_levels):
             corr = F.avg_pool2d(corr, 2, stride=2)
@@ -47,7 +47,7 @@ def corr(fmap1, fmap2):
         batch, dim, ht, wd = fmap1.shape
         fmap1 = fmap1.view(batch, dim, ht*wd)
         fmap2 = fmap2.view(batch, dim, ht*wd)
-        
+
         corr = torch.matmul(fmap1.transpose(1,2), fmap2)
         corr = corr.view(batch, ht, wd, 1, ht, wd)
-        return corr / torch.sqrt(torch.tensor(dim).float())
\ No newline at end of file
+        return corr / torch.sqrt(torch.tensor(dim).float())
diff --git a/core/raft.py b/core/raft.py
index e14a54a6..b12bffd6 100644
--- a/core/raft.py
+++ b/core/raft.py
@@ -3,10 +3,10 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from modules.update import BasicUpdateBlock, SmallUpdateBlock
-from modules.extractor import BasicEncoder, SmallEncoder
-from modules.corr import CorrBlock
-from utils.utils import bilinear_sampler, coords_grid, upflow8
+from .modules.update import BasicUpdateBlock, SmallUpdateBlock
+from .modules.extractor import BasicEncoder, SmallEncoder
+from .modules.corr import CorrBlock
+from .utils.utils import bilinear_sampler, coords_grid, upflow8
 
 
 class RAFT(nn.Module):
@@ -19,7 +19,7 @@ def __init__(self, args):
             self.context_dim = cdim = 64
             args.corr_levels = 4
             args.corr_radius = 3
-        
+
         else:
             self.hidden_dim = hdim = 128
             self.context_dim = cdim = 128
@@ -31,12 +31,12 @@ def __init__(self, args):
 
         # feature network, context network, and update block
         if args.small:
-            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)        
+            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
             self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
             self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
 
         else:
-            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)        
+            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
             self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
             self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
 
@@ -86,14 +86,12 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True):
 
             # F(t+1) = F(t) + \Delta(t)
             coords1 = coords1 + delta_flow
-            
+
             if upsample:
                 flow_up = upflow8(coords1 - coords0)
                 flow_predictions.append(flow_up)
-            
+
             else:
                 flow_predictions.append(coords1 - coords0)
 
         return flow_predictions
-
-
diff --git a/hubconf.py b/hubconf.py
new file mode 100644
index 00000000..86824639
--- /dev/null
+++ b/hubconf.py
@@ -0,0 +1,6 @@
+dependencies = ['torch']
+
+try:
+    from hubconf_models import RAFT
+except ModuleNotFoundError:
+    from .hubconf_models import RAFT
diff --git a/hubconf_models.py b/hubconf_models.py
new file mode 100644
index 00000000..09461700
--- /dev/null
+++ b/hubconf_models.py
@@ -0,0 +1,103 @@
+import argparse
+import io
+import os
+import time
+import urllib.request
+import zipfile
+import torch
+from torch.nn import functional as F
+
+try:
+    from core.raft import RAFT as RAFT_module
+except ModuleNotFoundError:
+    from .core.raft import RAFT as RAFT_module
+
+models_url = "https://www.dropbox.com/s/a2acvmczgzm6f9n/models.zip?dl=1"  # dl=1 is important
+
+
+__all__ = ["RAFT"]
+
+
+ENV_TORCH_HOME = "TORCH_HOME"
+ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
+DEFAULT_CACHE_DIR = "~/.cache"
+
+
+def _get_torch_home():
+    torch_home = os.path.expanduser(
+        os.getenv(
+            ENV_TORCH_HOME, os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch")
+        )
+    )
+    return torch_home
+
+
+def _pad8(img):
+    """Pad an image so that its spatial dimensions are divisible by 8."""
+    ht, wd = img.shape[2:]
+    pad_ht = (((ht // 8) + 1) * 8 - ht) % 8
+    pad_wd = (((wd // 8) + 1) * 8 - wd) % 8
+    pad_ht1 = [pad_ht // 2, pad_ht - pad_ht // 2]
+    pad_wd1 = [pad_wd // 2, pad_wd - pad_wd // 2]
+
+    img = F.pad(img, pad_wd1 + pad_ht1, mode="replicate")
+    return img
+
+
+def RAFT(pretrained=False, model_name="chairs+things", device=None, **kwargs):
+    """
+    RAFT model (https://arxiv.org/abs/2003.12039)
+    model_name (str): One of 'chairs+things', 'sintel', 'kitti', or 'small';
+        note that 'small' uses the smaller RAFT architecture
+    """
+
+    model_list = ["chairs+things", "sintel", "kitti", "small"]
+    if model_name not in model_list:
+        raise ValueError("Model should be one of " + str(model_list))
+
+    model_args = argparse.Namespace(**kwargs)
+    model_args.small = "small" in model_name
+
+    model = RAFT_module(model_args)
+    if device is None:
+        device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
+    if device != "cpu":
+        model = torch.nn.DataParallel(model, device_ids=[device])
+    else:
+        model = torch.nn.DataParallel(model)
+        model.device_ids = None
+
+    if pretrained:
+        torch_home = _get_torch_home()
+        model_dir = os.path.join(torch_home, "checkpoints", "models_RAFT")
+        model_path = os.path.join(model_dir, "models", model_name + ".pth")
+        if not os.path.exists(model_dir):
+            os.makedirs(model_dir, exist_ok=True)
+            response = urllib.request.urlopen(models_url, timeout=10)
+            z = zipfile.ZipFile(io.BytesIO(response.read()))
+            z.extractall(model_dir)
+        else:
+            time.sleep(10)  # give a concurrent download time to finish unzipping the models
+
+        map_location = torch.device('cpu') if device == "cpu" else None
+        model.load_state_dict(torch.load(model_path, map_location=map_location))
+
+    model = model.to(device)
+    model.eval()
+    return model
+
+
+def apply_model(model, images_from, images_to, iters=12, upsample=True):
+    """
+    Applies the optical flow model to pairs of images
+    Args:
+        images_from: torch.Tensor of size [B, C, H, W] containing RGB data for
+            images that serve as optical flow source images
+        images_to: torch.Tensor of size [B, C, H, W] containing RGB data for
+            images that serve as optical flow destination images
+    Return:
+        flow_predictions: list of torch.Tensor of size [B, 2, H, W]; the last entry is the final flow estimate
+    """
+    images_from, images_to = _pad8(images_from), _pad8(images_to)
+    with torch.no_grad():
+        return model(image1=images_from, image2=images_to, iters=iters, upsample=upsample)
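
A minimal usage sketch of the entry points added above. The checkpoint name, image sizes, and dropout=0 are illustrative assumptions; dropout must be supplied explicitly because RAFT's constructor reads args.dropout from the keyword arguments forwarded through argparse.Namespace.

import torch
from hubconf_models import RAFT, apply_model

# dropout is forwarded to the feature/context encoders via
# argparse.Namespace(**kwargs), so it has to be passed here; 0 disables
# dropout for inference. "sintel" maps to models/sintel.pth inside the
# downloaded zip by construction of model_path above.
model = RAFT(pretrained=True, model_name="sintel", device="cpu", dropout=0)

# Two RGB frames in [0, 255] with layout [B, C, H, W]; sizes are illustrative.
image1 = torch.randint(0, 256, (1, 3, 436, 1024)).float()
image2 = torch.randint(0, 256, (1, 3, 436, 1024)).float()

# _pad8 pads H=436 to 440 (two rows on each side), so the returned flow is at
# the padded resolution; W=1024 is already divisible by 8 and is unchanged.
flow_predictions = apply_model(model, image1, image2, iters=12)
flow = flow_predictions[-1]  # [B, 2, H, W], the final refinement iteration

Since hubconf.py exposes the same RAFT entry point, the model can equally be loaded through torch.hub pointed at a checkout of this repository instead of the direct import.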