+
+
+
+ SuperCATs: Cost Aggregation with Transformers for Sparse Correspondence
+
+ ICCE-Asia 2022
+
+
+
+
+
+
+ -
+ Seungjun Lee
+
+ -
+ Seungjun An
+
+ -
+ Sunghwan Hong
+
+ -
+ Seokju Cho
+
+ -
+ Jisu Nam
+
+ -
+ Susung Hong
+
+ -
+
+ Seungryong Kim
+
+
+
+ Korea University
+
+
+
+
+
+
+
+
+
+
+
+

+

+
+
+
+
+ Comparison between SuperCATs (left) and SuperGlue (right).
+
+
+
+
+
+
+
+
+ Abstract
+
+
+ In this work, we introduce a novel network, namely SuperCATs, which aims to find a correspondence field between visually similar images. SuperCATs stands on the shoulders of the recently proposed matching networks, SuperGlue and CATs, taking the merits of both for constructing an integrative framework. Specifically, given keypoints and corresponding descriptors, we first apply attentional aggregation consisting of self- and cross-attention graph neural networks to obtain feature descriptors. Subsequently, we construct a cost volume using the descriptors, which then undergoes a transformer aggregator for cost aggregation. With this approach, we manage to replace the handcrafted module based on solving an optimal transport problem initially included in SuperGlue with a transformer well known for its global receptive fields, making our approach more robust to severe deformations. We conduct experiments to demonstrate the effectiveness of the proposed method, and show that the proposed model is on par with SuperGlue for both indoor and outdoor scenes.
+
+
+
+
+
+
+
+
+
+ Architecture
+
+
+
+

+
+
+
+ Overall network architecture of SuperCATs.
+
+
+
+
+

+
+
+
+ Structure of Transformer Aggregator.
+
+
+
+
+
+
+
+
+
+
+
+ Acknowledgements
+
+
+ Thanks to our family Podo (cat), Aru (dog) and Dubu (dog) for their support. We love you.
+
+ The website template was borrowed from Michaël Gharbi.
+
+
+
+
+
+
diff --git a/make_load.py b/make_load.py
deleted file mode 100644
index a2389fc..0000000
--- a/make_load.py
+++ /dev/null
@@ -1,152 +0,0 @@
-from pathlib import Path
-import argparse
-import random
-import numpy as np
-import matplotlib.cm as cm
-import torch
-import torch.nn as nn
-from torch.autograd import Variable
-import os
-import torch.multiprocessing
-from tqdm import tqdm
-
-import cv2
-from scipy.spatial.distance import cdist
-
-from models.utils import (compute_pose_error, compute_epipolar_error,
- estimate_pose, make_matching_plot,
- error_colormap, AverageTimer, pose_auc, read_image,
- rotate_intrinsics, rotate_pose_inplane,
- scale_intrinsics, read_image_modified, frame2tensor)
-
-from models.matching import Matching
-from models.matchingsuperglue import Matching_ori
-from sjlee.loss import loss_superglue
-
-torch.set_grad_enabled(True)
-torch.multiprocessing.set_sharing_strategy('file_system')
-
-parser = argparse.ArgumentParser(
- description='Image pair matching and pose evaluation with SuperGlue',
- formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-
-parser.add_argument(
- '--viz', action='store_true',
- help='Visualize the matches and dump the plots')
-parser.add_argument(
- '--eval', action='store_true',
- help='Perform the evaluation'
- ' (requires ground truth pose and intrinsics)')
-
-parser.add_argument(
- '--superglue', choices={'indoor', 'outdoor'}, default='indoor',
- help='SuperGlue weights')
-parser.add_argument(
- '--max_keypoints', type=int, default=1023,
- help='Maximum number of keypoints detected by Superpoint'
- ' (\'-1\' keeps all keypoints)')
-parser.add_argument(
- '--keypoint_threshold', type=float, default=0.005,
- help='SuperPoint keypoint detector confidence threshold')
-parser.add_argument(
- '--nms_radius', type=int, default=4,
- help='SuperPoint Non Maximum Suppression (NMS) radius'
- ' (Must be positive)')
-parser.add_argument(
- '--sinkhorn_iterations', type=int, default=20,
- help='Number of Sinkhorn iterations performed by SuperGlue')
-parser.add_argument(
- '--match_threshold', type=float, default=0.2,
- help='SuperGlue match threshold')
-
-parser.add_argument(
- '--resize', type=int, nargs='+', default=[640, 480],
- help='Resize the input image before running inference. If two numbers, '
- 'resize to the exact dimensions, if one number, resize the max '
- 'dimension, if -1, do not resize')
-parser.add_argument(
- '--resize_float', action='store_true',
- help='Resize the image after casting uint8 to float')
-
-parser.add_argument(
- '--cache', action='store_true',
- help='Skip the pair if output .npz files are already found')
-parser.add_argument(
- '--show_keypoints', action='store_true',
- help='Plot the keypoints in addition to the matches')
-parser.add_argument(
- '--fast_viz', action='store_true',
- help='Use faster image visualization based on OpenCV instead of Matplotlib')
-parser.add_argument(
- '--viz_extension', type=str, default='png', choices=['png', 'pdf'],
- help='Visualization file extension. Use pdf for highest-quality.')
-
-parser.add_argument(
- '--opencv_display', action='store_true',
- help='Visualize via OpenCV before saving output images')
-parser.add_argument(
- '--eval_pairs_list', type=str, default='assets/scannet_sample_pairs_with_gt.txt',
- help='Path to the list of image pairs for evaluation')
-parser.add_argument(
- '--shuffle', action='store_true',
- help='Shuffle ordering of pairs before processing')
-parser.add_argument(
- '--max_length', type=int, default=-1,
- help='Maximum number of pairs to evaluate')
-
-parser.add_argument(
- '--eval_input_dir', type=str, default='assets/scannet_sample_images/',
- help='Path to the directory that contains the images')
-parser.add_argument(
- '--eval_output_dir', type=str, default='test_matches',
- help='Path to the directory in which the .npz results and optional,'
- 'visualizations are written')
-parser.add_argument(
- '--learning_rate', type=float, default=0.0001, #0.0001
- help='Learning rate')
-
-parser.add_argument(
- '--batch_size', type=int, default=1,
- help='batch_size')
-parser.add_argument(
- '--train_path', type=str, default='/home/cvlab09/projects/seungjun_an/dataset/train2014/',
- help='Path to the directory of training imgs.')
-parser.add_argument(
- '--epoch', type=int, default=1,
- help='Number of epoches')
-
-
-
-
-if __name__ == '__main__':
- opt = parser.parse_args()
- print(opt)
-
- # make sure the flags are properly used
- assert not (opt.opencv_display and not opt.viz), 'Must use --viz with --opencv_display'
- assert not (opt.opencv_display and not opt.fast_viz), 'Cannot use --opencv_display without --fast_viz'
- assert not (opt.fast_viz and not opt.viz), 'Must use --viz with --fast_viz'
- assert not (opt.fast_viz and opt.viz_extension == 'pdf'), 'Cannot use pdf extension with --fast_viz'
-
- numOftrainSet = 10
-
- # store viz results
- eval_output_dir = Path(opt.eval_output_dir)
- eval_output_dir.mkdir(exist_ok=True, parents=True)
- print('Will write visualization images to',
- 'directory \"{}\"'.format(eval_output_dir))
- config = {
- 'superpoint': {
- 'nms_radius': opt.nms_radius,
- 'keypoint_threshold': opt.keypoint_threshold,
- 'max_keypoints': opt.max_keypoints
- },
- 'superglue': {
- 'weights': opt.superglue,
- 'sinkhorn_iterations': opt.sinkhorn_iterations,
- 'match_threshold': opt.match_threshold,
- }
- }
- matching = Matching(config).eval().to('cuda')
- matching = torch.load('/home/cvlab09/projects/seungjun_an/superglue_test/model_epoch_1.pth')
- torch.save(matching.state_dict(), 'model_state_dict_epoch_1.pth')
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
deleted file mode 100755
index e69de29..0000000
diff --git a/models/__pycache__/__init__.cpython-37.pyc b/models/__pycache__/__init__.cpython-37.pyc
deleted file mode 100644
index 671b557..0000000
Binary files a/models/__pycache__/__init__.cpython-37.pyc and /dev/null differ
diff --git a/models/__pycache__/__init__.cpython-38.pyc b/models/__pycache__/__init__.cpython-38.pyc
deleted file mode 100644
index 94e80cc..0000000
Binary files a/models/__pycache__/__init__.cpython-38.pyc and /dev/null differ
diff --git a/models/__pycache__/__init__.cpython-39.pyc b/models/__pycache__/__init__.cpython-39.pyc
deleted file mode 100644
index 3916c66..0000000
Binary files a/models/__pycache__/__init__.cpython-39.pyc and /dev/null differ
diff --git a/models/__pycache__/matching.cpython-38.pyc b/models/__pycache__/matching.cpython-38.pyc
deleted file mode 100644
index ef8d31d..0000000
Binary files a/models/__pycache__/matching.cpython-38.pyc and /dev/null differ
diff --git a/models/__pycache__/matching.cpython-39.pyc b/models/__pycache__/matching.cpython-39.pyc
deleted file mode 100644
index f7cdf40..0000000
Binary files a/models/__pycache__/matching.cpython-39.pyc and /dev/null differ
diff --git a/models/__pycache__/matchingForTraining.cpython-37.pyc b/models/__pycache__/matchingForTraining.cpython-37.pyc
deleted file mode 100644
index 66b217a..0000000
Binary files a/models/__pycache__/matchingForTraining.cpython-37.pyc and /dev/null differ
diff --git a/models/__pycache__/matchingForTraining.cpython-38.pyc b/models/__pycache__/matchingForTraining.cpython-38.pyc
deleted file mode 100644
index c6f4350..0000000
Binary files a/models/__pycache__/matchingForTraining.cpython-38.pyc and /dev/null differ
diff --git a/models/__pycache__/matchingForTraining.cpython-39.pyc b/models/__pycache__/matchingForTraining.cpython-39.pyc
deleted file mode 100644
index 606416d..0000000
Binary files a/models/__pycache__/matchingForTraining.cpython-39.pyc and /dev/null differ
diff --git a/models/__pycache__/matching_backup.cpython-38.pyc b/models/__pycache__/matching_backup.cpython-38.pyc
deleted file mode 100644
index f1eabcb..0000000
Binary files a/models/__pycache__/matching_backup.cpython-38.pyc and /dev/null differ
diff --git a/models/__pycache__/matchingsuperglue.cpython-38.pyc b/models/__pycache__/matchingsuperglue.cpython-38.pyc
deleted file mode 100644
index feb3e4f..0000000
Binary files a/models/__pycache__/matchingsuperglue.cpython-38.pyc and /dev/null differ
diff --git a/models/__pycache__/superglue.cpython-37.pyc b/models/__pycache__/superglue.cpython-37.pyc
deleted file mode 100644
index 1bb1420..0000000
Binary files a/models/__pycache__/superglue.cpython-37.pyc and /dev/null differ
diff --git a/models/__pycache__/superglue.cpython-38.pyc b/models/__pycache__/superglue.cpython-38.pyc
deleted file mode 100644
index 0a3ad9f..0000000
Binary files a/models/__pycache__/superglue.cpython-38.pyc and /dev/null differ
diff --git a/models/__pycache__/superglue.cpython-39.pyc b/models/__pycache__/superglue.cpython-39.pyc
deleted file mode 100644
index 06813c1..0000000
Binary files a/models/__pycache__/superglue.cpython-39.pyc and /dev/null differ
diff --git a/models/__pycache__/superglue2.cpython-38.pyc b/models/__pycache__/superglue2.cpython-38.pyc
deleted file mode 100644
index c413c03..0000000
Binary files a/models/__pycache__/superglue2.cpython-38.pyc and /dev/null differ
diff --git a/models/__pycache__/superglue2.cpython-39.pyc b/models/__pycache__/superglue2.cpython-39.pyc
deleted file mode 100644
index a32f740..0000000
Binary files a/models/__pycache__/superglue2.cpython-39.pyc and /dev/null differ
diff --git a/models/__pycache__/superpoint.cpython-37.pyc b/models/__pycache__/superpoint.cpython-37.pyc
deleted file mode 100644
index a9890e3..0000000
Binary files a/models/__pycache__/superpoint.cpython-37.pyc and /dev/null differ
diff --git a/models/__pycache__/superpoint.cpython-38.pyc b/models/__pycache__/superpoint.cpython-38.pyc
deleted file mode 100644
index fb20221..0000000
Binary files a/models/__pycache__/superpoint.cpython-38.pyc and /dev/null differ
diff --git a/models/__pycache__/superpoint.cpython-39.pyc b/models/__pycache__/superpoint.cpython-39.pyc
deleted file mode 100644
index 8b1a999..0000000
Binary files a/models/__pycache__/superpoint.cpython-39.pyc and /dev/null differ
diff --git a/models/__pycache__/utils.cpython-37.pyc b/models/__pycache__/utils.cpython-37.pyc
deleted file mode 100644
index 74bd80f..0000000
Binary files a/models/__pycache__/utils.cpython-37.pyc and /dev/null differ
diff --git a/models/__pycache__/utils.cpython-38.pyc b/models/__pycache__/utils.cpython-38.pyc
deleted file mode 100644
index 3632a97..0000000
Binary files a/models/__pycache__/utils.cpython-38.pyc and /dev/null differ
diff --git a/models/__pycache__/utils.cpython-39.pyc b/models/__pycache__/utils.cpython-39.pyc
deleted file mode 100644
index b47e899..0000000
Binary files a/models/__pycache__/utils.cpython-39.pyc and /dev/null differ
diff --git a/models/matching.py b/models/matching.py
deleted file mode 100755
index f53f8ce..0000000
--- a/models/matching.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# %BANNER_BEGIN%
-# ---------------------------------------------------------------------
-# %COPYRIGHT_BEGIN%
-#
-# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
-#
-# Unpublished Copyright (c) 2020
-# Magic Leap, Inc., All Rights Reserved.
-#
-# NOTICE: All information contained herein is, and remains the property
-# of COMPANY. The intellectual and technical concepts contained herein
-# are proprietary to COMPANY and may be covered by U.S. and Foreign
-# Patents, patents in process, and are protected by trade secret or
-# copyright law. Dissemination of this information or reproduction of
-# this material is strictly forbidden unless prior written permission is
-# obtained from COMPANY. Access to the source code contained herein is
-# hereby forbidden to anyone except current COMPANY employees, managers
-# or contractors who have executed Confidentiality and Non-disclosure
-# agreements explicitly covering such access.
-#
-# The copyright notice above does not evidence any actual or intended
-# publication or disclosure of this source code, which includes
-# information that is confidential and/or proprietary, and is a trade
-# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
-# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
-# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
-# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
-# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
-# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
-# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
-# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
-#
-# %COPYRIGHT_END%
-# ----------------------------------------------------------------------
-# %AUTHORS_BEGIN%
-#
-# Originating Authors: Paul-Edouard Sarlin
-#
-# %AUTHORS_END%
-# --------------------------------------------------------------------*/
-# %BANNER_END%
-
-import torch
-
-from .superpoint import SuperPoint
-from .superglue2 import SuperGlue
-from sjlee.IMC import SimpleSuperCATs
-
-class Matching(torch.nn.Module):
- """ Image Matching Frontend (SuperPoint + SuperGlue) """
- def __init__(self, config={}):
- super().__init__()
- self.superpoint = SuperPoint(config.get('superpoint', {}))
- self.superglue = SuperGlue(config.get('superglue', {}))
- self.simsuper = SimpleSuperCATs(config.get('superglue', {}))
-
- def forward(self, data):
- """ Run SuperPoint (optionally) and SuperGlue
- SuperPoint is skipped if ['keypoints0', 'keypoints1'] exist in input
- Args:
- data: dictionary with minimal keys: ['image0', 'image1']
- """
- pred = {}
-
- # Extract SuperPoint (keypoints, scores, descriptors) if not provided
- with torch.no_grad():
- if 'keypoints0' not in data:
- pred0 = self.superpoint({'image': data['image0']})
- pred = {**pred, **{k+'0': v for k, v in pred0.items()}}
- if 'keypoints1' not in data:
- pred1 = self.superpoint({'image': data['image1']})
- pred = {**pred, **{k+'1': v for k, v in pred1.items()}}
-
- # Batch all features
- # We should either have i) one image per batch, or
- # ii) the same number of local features for all images in the batch.
- data = {**data, **pred}
-
- for k in data:
- if isinstance(data[k], (list, tuple)):
- data[k] = torch.stack(data[k])
- data[k].requres_grad = True
-
-
- data['keypoints0'], data['keypoints1'] = data['keypoints0'].unsqueeze(0), data['keypoints1'].unsqueeze(0)
- data['scores0'], data['scores1'] = data['scores0'].transpose(0,1), data['scores1'].transpose(0,1)
- data['descriptors0'], data['descriptors1'] = data['descriptors0'].transpose(0, 1), data['descriptors1'].transpose(0, 1)
-
- for k in data:
- if k == 'file_name' or k == 'skip_train':
- continue
- data[k].requres_grad = True
- data[k] = data[k]
- ##print(data.keys())
- ##print(data['keypoints0'].size())
- ##print(data['scores0'].size())
- ##print(data['descriptors0'].size())
-
- # Perform the matching
- #pred = {**pred, **self.superglue(data)}
-
- scores, data2 = self.simsuper(data)
-
- return scores, data2, data
diff --git a/models/matchingsuperglue.py b/models/matchingsuperglue.py
deleted file mode 100755
index 06ad7fd..0000000
--- a/models/matchingsuperglue.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# %BANNER_BEGIN%
-# ---------------------------------------------------------------------
-# %COPYRIGHT_BEGIN%
-#
-# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
-#
-# Unpublished Copyright (c) 2020
-# Magic Leap, Inc., All Rights Reserved.
-#
-# NOTICE: All information contained herein is, and remains the property
-# of COMPANY. The intellectual and technical concepts contained herein
-# are proprietary to COMPANY and may be covered by U.S. and Foreign
-# Patents, patents in process, and are protected by trade secret or
-# copyright law. Dissemination of this information or reproduction of
-# this material is strictly forbidden unless prior written permission is
-# obtained from COMPANY. Access to the source code contained herein is
-# hereby forbidden to anyone except current COMPANY employees, managers
-# or contractors who have executed Confidentiality and Non-disclosure
-# agreements explicitly covering such access.
-#
-# The copyright notice above does not evidence any actual or intended
-# publication or disclosure of this source code, which includes
-# information that is confidential and/or proprietary, and is a trade
-# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
-# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
-# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
-# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
-# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
-# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
-# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
-# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
-#
-# %COPYRIGHT_END%
-# ----------------------------------------------------------------------
-# %AUTHORS_BEGIN%
-#
-# Originating Authors: Paul-Edouard Sarlin
-#
-# %AUTHORS_END%
-# --------------------------------------------------------------------*/
-# %BANNER_END%
-
-import torch
-
-from .superpoint import SuperPoint
-from .superglue2 import SuperGlue
-from sjlee_backup.IMCsuperglue import SimpleSuperCATs
-
-class Matching_ori(torch.nn.Module):
- """ Image Matching Frontend (SuperPoint + SuperGlue) """
- def __init__(self, config={}):
- super().__init__()
- self.superpoint = SuperPoint(config.get('superpoint', {}))
- self.superglue = SuperGlue(config.get('superglue', {}))
- self.simsuper = SimpleSuperCATs(config.get('superglue', {}))
-
- def forward(self, data):
- """ Run SuperPoint (optionally) and SuperGlue
- SuperPoint is skipped if ['keypoints0', 'keypoints1'] exist in input
- Args:
- data: dictionary with minimal keys: ['image0', 'image1']
- """
- pred = {}
-
- # Extract SuperPoint (keypoints, scores, descriptors) if not provided
- with torch.no_grad():
- if 'keypoints0' not in data:
- pred0 = self.superpoint({'image': data['image0']})
- pred = {**pred, **{k+'0': v for k, v in pred0.items()}}
- if 'keypoints1' not in data:
- pred1 = self.superpoint({'image': data['image1']})
- pred = {**pred, **{k+'1': v for k, v in pred1.items()}}
-
- # Batch all features
- # We should either have i) one image per batch, or
- # ii) the same number of local features for all images in the batch.
- data = {**data, **pred}
-
- for k in data:
- if isinstance(data[k], (list, tuple)):
- data[k] = torch.stack(data[k])
- data[k].requres_grad = True
-
- self.superglue(data)
-
- data['keypoints0'], data['keypoints1'] = data['keypoints0'].unsqueeze(0), data['keypoints1'].unsqueeze(0)
- data['scores0'], data['scores1'] = data['scores0'].transpose(0,1), data['scores1'].transpose(0,1)
- data['descriptors0'], data['descriptors1'] = data['descriptors0'].transpose(0, 1), data['descriptors1'].transpose(0, 1)
-
- for k in data:
- if k == 'file_name' or k == 'skip_train':
- continue
- data[k].requres_grad = True
-
- ##print(data.keys())
- ##print(data['keypoints0'].size())
- ##print(data['scores0'].size())
- ##print(data['descriptors0'].size())
-
- # Perform the matching
- #pred = {**pred, **self.superglue(data)}
-
- scores, data2 = self.simsuper(data)
-
- return scores, data2, data
diff --git a/models/superglue2.py b/models/superglue2.py
deleted file mode 100644
index 9605a75..0000000
--- a/models/superglue2.py
+++ /dev/null
@@ -1,290 +0,0 @@
-# %BANNER_BEGIN%
-# ---------------------------------------------------------------------
-# %COPYRIGHT_BEGIN%
-#
-# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
-#
-# Unpublished Copyright (c) 2020
-# Magic Leap, Inc., All Rights Reserved.
-#
-# NOTICE: All information contained herein is, and remains the property
-# of COMPANY. The intellectual and technical concepts contained herein
-# are proprietary to COMPANY and may be covered by U.S. and Foreign
-# Patents, patents in process, and are protected by trade secret or
-# copyright law. Dissemination of this information or reproduction of
-# this material is strictly forbidden unless prior written permission is
-# obtained from COMPANY. Access to the source code contained herein is
-# hereby forbidden to anyone except current COMPANY employees, managers
-# or contractors who have executed Confidentiality and Non-disclosure
-# agreements explicitly covering such access.
-#
-# The copyright notice above does not evidence any actual or intended
-# publication or disclosure of this source code, which includes
-# information that is confidential and/or proprietary, and is a trade
-# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
-# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
-# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
-# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
-# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
-# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
-# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
-# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
-#
-# %COPYRIGHT_END%
-# ----------------------------------------------------------------------
-# %AUTHORS_BEGIN%
-#
-# Originating Authors: Paul-Edouard Sarlin
-#
-# %AUTHORS_END%
-# --------------------------------------------------------------------*/
-# %BANNER_END%
-
-from copy import deepcopy
-from pathlib import Path
-from typing import List, Tuple
-
-import torch
-from torch import nn
-
-
-def MLP(channels: List[int], do_bn: bool = True) -> nn.Module:
- """ Multi-layer perceptron """
- n = len(channels)
- layers = []
- for i in range(1, n):
- layers.append(
- nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
- if i < (n-1):
- if do_bn:
- layers.append(nn.BatchNorm1d(channels[i]))
- layers.append(nn.ReLU())
- return nn.Sequential(*layers)
-
-
-def normalize_keypoints(kpts, image_shape):
- """ Normalize keypoints locations based on image image_shape"""
- _, _, height, width = image_shape
- one = kpts.new_tensor(1)
- size = torch.stack([one*width, one*height])[None]
- center = size / 2
- scaling = size.max(1, keepdim=True).values * 0.7
- return (kpts - center[:, None, :]) / scaling[:, None, :]
-
-
-class KeypointEncoder(nn.Module):
- """ Joint encoding of visual appearance and location using MLPs"""
- def __init__(self, feature_dim: int, layers: List[int]) -> None:
- super().__init__()
- self.encoder = MLP([3] + layers + [feature_dim])
- nn.init.constant_(self.encoder[-1].bias, 0.0)
-
- def forward(self, kpts, scores):
- inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)]
- return self.encoder(torch.cat(inputs, dim=1))
-
-
-def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]:
- dim = query.shape[1]
- scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
- prob = torch.nn.functional.softmax(scores, dim=-1)
- return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
-
-
-class MultiHeadedAttention(nn.Module):
- """ Multi-head attention to increase model expressivitiy """
- def __init__(self, num_heads: int, d_model: int):
- super().__init__()
- assert d_model % num_heads == 0
- self.dim = d_model // num_heads
- self.num_heads = num_heads
- self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
- self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
-
- def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
- batch_dim = query.size(0)
- query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
- for l, x in zip(self.proj, (query, key, value))]
- x, _ = attention(query, key, value)
- return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))
-
-
-class AttentionalPropagation(nn.Module):
- def __init__(self, feature_dim: int, num_heads: int):
- super().__init__()
- self.attn = MultiHeadedAttention(num_heads, feature_dim)
- self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])
- nn.init.constant_(self.mlp[-1].bias, 0.0)
-
- def forward(self, x: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
- message = self.attn(x, source, source)
- return self.mlp(torch.cat([x, message], dim=1))
-
-
-class AttentionalGNN(nn.Module):
- def __init__(self, feature_dim: int, layer_names: List[str]) -> None:
- super().__init__()
- self.layers = nn.ModuleList([
- AttentionalPropagation(feature_dim, 4)
- for _ in range(len(layer_names))])
- self.names = layer_names
-
- def forward(self, desc0: torch.Tensor, desc1: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]:
- for layer, name in zip(self.layers, self.names):
- if name == 'cross':
- src0, src1 = desc1, desc0
- else: # if name == 'self':
- src0, src1 = desc0, desc1
- delta0, delta1 = layer(desc0, src0), layer(desc1, src1)
- desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
- return desc0, desc1
-
-
-def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int) -> torch.Tensor:
- """ Perform Sinkhorn Normalization in Log-space for stability"""
- u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
- for _ in range(iters):
- u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
- v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
- return Z + u.unsqueeze(2) + v.unsqueeze(1)
-
-
-def log_optimal_transport(scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor:
- """ Perform Differentiable Optimal Transport in Log-space for stability"""
- b, m, n = scores.shape
- one = scores.new_tensor(1)
- ms, ns = (m*one).to(scores), (n*one).to(scores)
-
- bins0 = alpha.expand(b, m, 1)
- bins1 = alpha.expand(b, 1, n)
- alpha = alpha.expand(b, 1, 1)
-
- couplings = torch.cat([torch.cat([scores, bins0], -1),
- torch.cat([bins1, alpha], -1)], 1)
-
- norm = - (ms + ns).log()
- log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])
- log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
- log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)
-
- Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
- Z = Z - norm # multiply probabilities by M+N
- return Z
-
-
-def arange_like(x, dim: int):
- return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1
-
-
-class SuperGlue(nn.Module):
- """SuperGlue feature matching middle-end
-
- Given two sets of keypoints and locations, we determine the
- correspondences by:
- 1. Keypoint Encoding (normalization + visual feature and location fusion)
- 2. Graph Neural Network with multiple self and cross-attention layers
- 3. Final projection layer
- 4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
- 5. Thresholding matrix based on mutual exclusivity and a match_threshold
-
- The correspondence ids use -1 to indicate non-matching points.
-
- Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
- Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
- Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763
-
- """
- default_config = {
- 'descriptor_dim': 256,
- 'weights': 'indoor',
- 'keypoint_encoder': [32, 64, 128, 256],
- 'GNN_layers': ['self', 'cross'] * 9,
- 'sinkhorn_iterations': 100,
- 'match_threshold': 0.2,
- }
-
- def __init__(self, config):
- super().__init__()
- self.config = {**self.default_config, **config}
-
- self.kenc = KeypointEncoder(
- self.config['descriptor_dim'], self.config['keypoint_encoder'])
-
- self.gnn = AttentionalGNN(
- feature_dim=self.config['descriptor_dim'], layer_names=self.config['GNN_layers'])
-
- self.final_proj = nn.Conv1d(
- self.config['descriptor_dim'], self.config['descriptor_dim'],
- kernel_size=1, bias=True)
-
- bin_score = torch.nn.Parameter(torch.tensor(1.))
- self.register_parameter('bin_score', bin_score)
-
- assert self.config['weights'] in ['indoor', 'outdoor']
- path = Path(__file__).parent
- path = path / 'weights/superglue_{}.pth'.format(self.config['weights'])
- self.load_state_dict(torch.load(str(path)))
- print('Loaded SuperGlue model (\"{}\" weights)'.format(
- self.config['weights']))
-
- def forward(self, data):
- """Run SuperGlue on a pair of keypoints and descriptors"""
- desc0, desc1 = data['descriptors0'], data['descriptors1']
- kpts0, kpts1 = data['keypoints0'], data['keypoints1']
-
- if kpts0.shape[1] == 0 or kpts1.shape[1] == 0: # no keypoints
- shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
- return {
- 'matches0': kpts0.new_full(shape0, -1, dtype=torch.int),
- 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int),
- 'matching_scores0': kpts0.new_zeros(shape0),
- 'matching_scores1': kpts1.new_zeros(shape1),
- }
-
- # Keypoint normalization.
- kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
- kpts1 = normalize_keypoints(kpts1, data['image1'].shape)
-
- # Keypoint MLP encoder.
- desc0 = desc0 + self.kenc(kpts0, data['scores0'])
- desc1 = desc1 + self.kenc(kpts1, data['scores1'])
-
- # Multi-layer Transformer network.
- desc0, desc1 = self.gnn(desc0, desc1)
-
- # Final MLP projection.
- mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
-
- # Compute matching descriptor distance.
- scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
- scores = scores / self.config['descriptor_dim']**.5
-
- #print(scores)
- #print(scores.max(), scores.min())
-
- # Run the optimal transport.
- scores = log_optimal_transport(
- scores, self.bin_score,
- iters=self.config['sinkhorn_iterations'])
-
-
-
- # Get the matches with score above "match_threshold".
- max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
- indices0, indices1 = max0.indices, max1.indices
- mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
- mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
- zero = scores.new_tensor(0)
- mscores0 = torch.where(mutual0, max0.values.exp(), zero)
- mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
- valid0 = mutual0 & (mscores0 > self.config['match_threshold'])
- valid1 = mutual1 & valid0.gather(1, indices1)
- indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
- indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
-
- return {
- 'matches0': indices0, # use -1 for invalid match
- 'matches1': indices1, # use -1 for invalid match
- 'matching_scores0': mscores0,
- 'matching_scores1': mscores1,
- }
diff --git a/models/superpoint.py b/models/superpoint.py
deleted file mode 100755
index 8e41192..0000000
--- a/models/superpoint.py
+++ /dev/null
@@ -1,202 +0,0 @@
-# %BANNER_BEGIN%
-# ---------------------------------------------------------------------
-# %COPYRIGHT_BEGIN%
-#
-# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
-#
-# Unpublished Copyright (c) 2020
-# Magic Leap, Inc., All Rights Reserved.
-#
-# NOTICE: All information contained herein is, and remains the property
-# of COMPANY. The intellectual and technical concepts contained herein
-# are proprietary to COMPANY and may be covered by U.S. and Foreign
-# Patents, patents in process, and are protected by trade secret or
-# copyright law. Dissemination of this information or reproduction of
-# this material is strictly forbidden unless prior written permission is
-# obtained from COMPANY. Access to the source code contained herein is
-# hereby forbidden to anyone except current COMPANY employees, managers
-# or contractors who have executed Confidentiality and Non-disclosure
-# agreements explicitly covering such access.
-#
-# The copyright notice above does not evidence any actual or intended
-# publication or disclosure of this source code, which includes
-# information that is confidential and/or proprietary, and is a trade
-# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
-# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
-# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
-# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
-# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
-# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
-# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
-# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
-#
-# %COPYRIGHT_END%
-# ----------------------------------------------------------------------
-# %AUTHORS_BEGIN%
-#
-# Originating Authors: Paul-Edouard Sarlin
-#
-# %AUTHORS_END%
-# --------------------------------------------------------------------*/
-# %BANNER_END%
-
-from pathlib import Path
-import torch
-from torch import nn
-
-def simple_nms(scores, nms_radius: int):
- """ Fast Non-maximum suppression to remove nearby points """
- assert(nms_radius >= 0)
-
- def max_pool(x):
- return torch.nn.functional.max_pool2d(
- x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius)
-
- zeros = torch.zeros_like(scores)
- max_mask = scores == max_pool(scores)
- for _ in range(2):
- supp_mask = max_pool(max_mask.float()) > 0
- supp_scores = torch.where(supp_mask, zeros, scores)
- new_max_mask = supp_scores == max_pool(supp_scores)
- max_mask = max_mask | (new_max_mask & (~supp_mask))
- return torch.where(max_mask, scores, zeros)
-
-
-def remove_borders(keypoints, scores, border: int, height: int, width: int):
- """ Removes keypoints too close to the border """
- mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
- mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
- mask = mask_h & mask_w
- return keypoints[mask], scores[mask]
-
-
-def top_k_keypoints(keypoints, scores, k: int):
- if k >= len(keypoints):
- return keypoints, scores
- scores, indices = torch.topk(scores, k, dim=0)
- return keypoints[indices], scores
-
-
-def sample_descriptors(keypoints, descriptors, s: int = 8):
- """ Interpolate descriptors at keypoint locations """
- b, c, h, w = descriptors.shape
- keypoints = keypoints - s / 2 + 0.5
- keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)],
- ).to(keypoints)[None]
- keypoints = keypoints*2 - 1 # normalize to (-1, 1)
- args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {}
- descriptors = torch.nn.functional.grid_sample(
- descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
- descriptors = torch.nn.functional.normalize(
- descriptors.reshape(b, c, -1), p=2, dim=1)
- return descriptors
-
-
-class SuperPoint(nn.Module):
- """SuperPoint Convolutional Detector and Descriptor
-
- SuperPoint: Self-Supervised Interest Point Detection and
- Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew
- Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629
-
- """
- default_config = {
- 'descriptor_dim': 256,
- 'nms_radius': 4,
- 'keypoint_threshold': 0.005,
- 'max_keypoints': -1,
- 'remove_borders': 4,
- }
-
- def __init__(self, config):
- super().__init__()
- self.config = {**self.default_config, **config}
-
- self.relu = nn.ReLU(inplace=True)
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
- c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
-
- self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
- self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
- self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
- self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
- self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
- self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
- self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
- self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
-
- self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
- self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
-
- self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
- self.convDb = nn.Conv2d(
- c5, self.config['descriptor_dim'],
- kernel_size=1, stride=1, padding=0)
-
- path = Path(__file__).parent / 'weights/superpoint_v1.pth'
- self.load_state_dict(torch.load(str(path)))
-
- mk = self.config['max_keypoints']
- if mk == 0 or mk < -1:
- raise ValueError('\"max_keypoints\" must be positive or \"-1\"')
-
- print('Loaded SuperPoint model')
-
- def forward(self, data):
- """ Compute keypoints, scores, descriptors for image """
- # Shared Encoder
- x = self.relu(self.conv1a(data['image']))
- x = self.relu(self.conv1b(x))
- x = self.pool(x)
- x = self.relu(self.conv2a(x))
- x = self.relu(self.conv2b(x))
- x = self.pool(x)
- x = self.relu(self.conv3a(x))
- x = self.relu(self.conv3b(x))
- x = self.pool(x)
- x = self.relu(self.conv4a(x))
- x = self.relu(self.conv4b(x))
-
- # Compute the dense keypoint scores
- cPa = self.relu(self.convPa(x))
- scores = self.convPb(cPa)
- scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
- b, _, h, w = scores.shape
- scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
- scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)
- scores = simple_nms(scores, self.config['nms_radius'])
-
- # Extract keypoints
- keypoints = [
- torch.nonzero(s > self.config['keypoint_threshold'])
- for s in scores]
- scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
-
- # Discard keypoints near the image borders
- keypoints, scores = list(zip(*[
- remove_borders(k, s, self.config['remove_borders'], h*8, w*8)
- for k, s in zip(keypoints, scores)]))
-
- # Keep the k keypoints with highest score
- if self.config['max_keypoints'] >= 0:
- keypoints, scores = list(zip(*[
- top_k_keypoints(k, s, self.config['max_keypoints'])
- for k, s in zip(keypoints, scores)]))
-
- # Convert (h, w) to (x, y)
- keypoints = [torch.flip(k, [1]).float() for k in keypoints]
-
- # Compute the dense descriptors
- cDa = self.relu(self.convDa(x))
- descriptors = self.convDb(cDa)
- descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
-
- # Extract descriptors
- descriptors = [sample_descriptors(k[None], d[None], 8)[0]
- for k, d in zip(keypoints, descriptors)]
-
- return {
- 'keypoints': keypoints,
- 'scores': scores,
- 'descriptors': descriptors,
- }
diff --git a/models/utils.py b/models/utils.py
deleted file mode 100755
index 6b4ec97..0000000
--- a/models/utils.py
+++ /dev/null
@@ -1,558 +0,0 @@
-# %BANNER_BEGIN%
-# ---------------------------------------------------------------------
-# %COPYRIGHT_BEGIN%
-#
-# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
-#
-# Unpublished Copyright (c) 2020
-# Magic Leap, Inc., All Rights Reserved.
-#
-# NOTICE: All information contained herein is, and remains the property
-# of COMPANY. The intellectual and technical concepts contained herein
-# are proprietary to COMPANY and may be covered by U.S. and Foreign
-# Patents, patents in process, and are protected by trade secret or
-# copyright law. Dissemination of this information or reproduction of
-# this material is strictly forbidden unless prior written permission is
-# obtained from COMPANY. Access to the source code contained herein is
-# hereby forbidden to anyone except current COMPANY employees, managers
-# or contractors who have executed Confidentiality and Non-disclosure
-# agreements explicitly covering such access.
-#
-# The copyright notice above does not evidence any actual or intended
-# publication or disclosure of this source code, which includes
-# information that is confidential and/or proprietary, and is a trade
-# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
-# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
-# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
-# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
-# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
-# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
-# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
-# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
-#
-# %COPYRIGHT_END%
-# ----------------------------------------------------------------------
-# %AUTHORS_BEGIN%
-#
-# Originating Authors: Paul-Edouard Sarlin
-# Daniel DeTone
-# Tomasz Malisiewicz
-#
-# %AUTHORS_END%
-# --------------------------------------------------------------------*/
-# %BANNER_END%
-
-from pathlib import Path
-import time
-from collections import OrderedDict
-from threading import Thread
-import numpy as np
-import cv2
-import torch
-import matplotlib.pyplot as plt
-import matplotlib
-matplotlib.use('Agg')
-
-
-class AverageTimer:
- """ Class to help manage printing simple timing of code execution. """
-
- def __init__(self, smoothing=0.3, newline=False):
- self.smoothing = smoothing
- self.newline = newline
- self.times = OrderedDict()
- self.will_print = OrderedDict()
- self.reset()
-
- def reset(self):
- now = time.time()
- self.start = now
- self.last_time = now
- for name in self.will_print:
- self.will_print[name] = False
-
- def update(self, name='default'):
- now = time.time()
- dt = now - self.last_time
- if name in self.times:
- dt = self.smoothing * dt + (1 - self.smoothing) * self.times[name]
- self.times[name] = dt
- self.will_print[name] = True
- self.last_time = now
-
- def print(self, text='Timer'):
- total = 0.
- print('[{}]'.format(text), end=' ')
- for key in self.times:
- val = self.times[key]
- if self.will_print[key]:
- print('%s=%.3f' % (key, val), end=' ')
- total += val
- print('total=%.3f sec {%.1f FPS}' % (total, 1./total), end=' ')
- if self.newline:
- print(flush=True)
- else:
- print(end='\r', flush=True)
- self.reset()
-
-
-class VideoStreamer:
- """ Class to help process image streams. Four types of possible inputs:"
- 1.) USB Webcam.
- 2.) An IP camera
- 3.) A directory of images (files in directory matching 'image_glob').
- 4.) A video file, such as an .mp4 or .avi file.
- """
- def __init__(self, basedir, resize, skip, image_glob, max_length=1000000):
- self._ip_grabbed = False
- self._ip_running = False
- self._ip_camera = False
- self._ip_image = None
- self._ip_index = 0
- self.cap = []
- self.camera = True
- self.video_file = False
- self.listing = []
- self.resize = resize
- self.interp = cv2.INTER_AREA
- self.i = 0
- self.skip = skip
- self.max_length = max_length
- if isinstance(basedir, int) or basedir.isdigit():
- print('==> Processing USB webcam input: {}'.format(basedir))
- self.cap = cv2.VideoCapture(int(basedir))
- self.listing = range(0, self.max_length)
- elif basedir.startswith(('http', 'rtsp')):
- print('==> Processing IP camera input: {}'.format(basedir))
- self.cap = cv2.VideoCapture(basedir)
- self.start_ip_camera_thread()
- self._ip_camera = True
- self.listing = range(0, self.max_length)
- elif Path(basedir).is_dir():
- print('==> Processing image directory input: {}'.format(basedir))
- self.listing = list(Path(basedir).glob(image_glob[0]))
- for j in range(1, len(image_glob)):
- image_path = list(Path(basedir).glob(image_glob[j]))
- self.listing = self.listing + image_path
- self.listing.sort()
- self.listing = self.listing[::self.skip]
- self.max_length = np.min([self.max_length, len(self.listing)])
- if self.max_length == 0:
- raise IOError('No images found (maybe bad \'image_glob\' ?)')
- self.listing = self.listing[:self.max_length]
- self.camera = False
- elif Path(basedir).exists():
- print('==> Processing video input: {}'.format(basedir))
- self.cap = cv2.VideoCapture(basedir)
- self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
- num_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
- self.listing = range(0, num_frames)
- self.listing = self.listing[::self.skip]
- self.video_file = True
- self.max_length = np.min([self.max_length, len(self.listing)])
- self.listing = self.listing[:self.max_length]
- else:
- raise ValueError('VideoStreamer input \"{}\" not recognized.'.format(basedir))
- if self.camera and not self.cap.isOpened():
- raise IOError('Could not read camera')
-
- def load_image(self, impath):
- """ Read image as grayscale and resize to img_size.
- Inputs
- impath: Path to input image.
- Returns
- grayim: uint8 numpy array sized H x W.
- """
- grayim = cv2.imread(impath, 0)
- if grayim is None:
- raise Exception('Error reading image %s' % impath)
- w, h = grayim.shape[1], grayim.shape[0]
- w_new, h_new = process_resize(w, h, self.resize)
- grayim = cv2.resize(
- grayim, (w_new, h_new), interpolation=self.interp)
- return grayim
-
- def next_frame(self):
- """ Return the next frame, and increment internal counter.
- Returns
- image: Next H x W image.
- status: True or False depending whether image was loaded.
- """
-
- if self.i == self.max_length:
- return (None, False)
- if self.camera:
-
- if self._ip_camera:
- #Wait for first image, making sure we haven't exited
- while self._ip_grabbed is False and self._ip_exited is False:
- time.sleep(.001)
-
- ret, image = self._ip_grabbed, self._ip_image.copy()
- if ret is False:
- self._ip_running = False
- else:
- ret, image = self.cap.read()
- if ret is False:
- print('VideoStreamer: Cannot get image from camera')
- return (None, False)
- w, h = image.shape[1], image.shape[0]
- if self.video_file:
- self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.listing[self.i])
-
- w_new, h_new = process_resize(w, h, self.resize)
- image = cv2.resize(image, (w_new, h_new),
- interpolation=self.interp)
- image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
- else:
- image_file = str(self.listing[self.i])
- image = self.load_image(image_file)
- self.i = self.i + 1
- return (image, True)
-
- def start_ip_camera_thread(self):
- self._ip_thread = Thread(target=self.update_ip_camera, args=())
- self._ip_running = True
- self._ip_thread.start()
- self._ip_exited = False
- return self
-
- def update_ip_camera(self):
- while self._ip_running:
- ret, img = self.cap.read()
- if ret is False:
- self._ip_running = False
- self._ip_exited = True
- self._ip_grabbed = False
- return
-
- self._ip_image = img
- self._ip_grabbed = ret
- self._ip_index += 1
- #print('IPCAMERA THREAD got frame {}'.format(self._ip_index))
-
-
- def cleanup(self):
- self._ip_running = False
-
-# --- PREPROCESSING ---
-
-def process_resize(w, h, resize):
- assert(len(resize) > 0 and len(resize) <= 2)
- if len(resize) == 1 and resize[0] > -1:
- scale = resize[0] / max(h, w)
- w_new, h_new = int(round(w*scale)), int(round(h*scale))
- elif len(resize) == 1 and resize[0] == -1:
- w_new, h_new = w, h
- else: # len(resize) == 2:
- w_new, h_new = resize[0], resize[1]
-
- # Issue warning if resolution is too small or too large.
- if max(w_new, h_new) < 160:
- print('Warning: input resolution is very small, results may vary')
- elif max(w_new, h_new) > 2000:
- print('Warning: input resolution is very large, results may vary')
-
- return w_new, h_new
-
-
-def frame2tensor(frame):
- return torch.from_numpy(frame/255.).float()[None, None].cuda()
-
-
-def read_image(path, resize, rotation, resize_float):
- image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
- if image is None:
- return None, None, None
- w, h = image.shape[1], image.shape[0]
- w_new, h_new = process_resize(w, h, resize)
- scales = (float(w) / float(w_new), float(h) / float(h_new))
-
- if resize_float:
- image = cv2.resize(image.astype('float32'), (w_new, h_new))
- else:
- image = cv2.resize(image, (w_new, h_new)).astype('float32')
-
- if rotation != 0:
- image = np.rot90(image, k=rotation)
- if rotation % 2:
- scales = scales[::-1]
-
- inp = frame2tensor(image)
- return image, inp, scales
-
-
-
-def read_image_modified(image, resize, resize_float):
- if image is None:
- return None, None, None
- w, h = image.shape[1], image.shape[0]
- w_new, h_new = process_resize(w, h, resize)
- scales = (float(w) / float(w_new), float(h) / float(h_new))
- if resize_float:
- image = cv2.resize(image.astype('float32'), (w_new, h_new))
- else:
- image = cv2.resize(image, (w_new, h_new)).astype('float32')
- return image
-# --- GEOMETRY ---
-
-
-def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
- if len(kpts0) < 5:
- return None
-
- f_mean = np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]])
- norm_thresh = thresh / f_mean
-
- kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
- kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
-
- E, mask = cv2.findEssentialMat(
- kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf,
- method=cv2.RANSAC)
-
- assert E is not None
-
- best_num_inliers = 0
- ret = None
- for _E in np.split(E, len(E) / 3):
- n, R, t, mask_new = cv2.recoverPose(
- _E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
- if n > best_num_inliers:
- best_num_inliers = n
- ret = (R, t[:, 0], mask.ravel() > 0)
- return ret
-
-
-def rotate_intrinsics(K, image_shape, rot):
- """image_shape is the shape of the image after rotation"""
- assert rot <= 3
- h, w = image_shape[:2][::-1 if (rot % 2) else 1]
- fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
- rot = rot % 4
- if rot == 1:
- return np.array([[fy, 0., cy],
- [0., fx, w-1-cx],
- [0., 0., 1.]], dtype=K.dtype)
- elif rot == 2:
- return np.array([[fx, 0., w-1-cx],
- [0., fy, h-1-cy],
- [0., 0., 1.]], dtype=K.dtype)
- else: # if rot == 3:
- return np.array([[fy, 0., h-1-cy],
- [0., fx, cx],
- [0., 0., 1.]], dtype=K.dtype)
-
-
-def rotate_pose_inplane(i_T_w, rot):
- rotation_matrices = [
- np.array([[np.cos(r), -np.sin(r), 0., 0.],
- [np.sin(r), np.cos(r), 0., 0.],
- [0., 0., 1., 0.],
- [0., 0., 0., 1.]], dtype=np.float32)
- for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
- ]
- return np.dot(rotation_matrices[rot], i_T_w)
-
-
-def scale_intrinsics(K, scales):
- scales = np.diag([1./scales[0], 1./scales[1], 1.])
- return np.dot(scales, K)
-
-
-def to_homogeneous(points):
- return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
-
-
-def compute_epipolar_error(kpts0, kpts1, T_0to1, K0, K1):
- kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
- kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
- kpts0 = to_homogeneous(kpts0)
- kpts1 = to_homogeneous(kpts1)
-
- t0, t1, t2 = T_0to1[:3, 3]
- t_skew = np.array([
- [0, -t2, t1],
- [t2, 0, -t0],
- [-t1, t0, 0]
- ])
- E = t_skew @ T_0to1[:3, :3]
-
- Ep0 = kpts0 @ E.T # N x 3
- p1Ep0 = np.sum(kpts1 * Ep0, -1) # N
- Etp1 = kpts1 @ E # N x 3
- d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2)
- + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2))
- return d
-
-
-def angle_error_mat(R1, R2):
- cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
- cos = np.clip(cos, -1., 1.) # numercial errors can make it out of bounds
- return np.rad2deg(np.abs(np.arccos(cos)))
-
-
-def angle_error_vec(v1, v2):
- n = np.linalg.norm(v1) * np.linalg.norm(v2)
- return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
-
-
-def compute_pose_error(T_0to1, R, t):
- R_gt = T_0to1[:3, :3]
- t_gt = T_0to1[:3, 3]
- error_t = angle_error_vec(t, t_gt)
- error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
- error_R = angle_error_mat(R, R_gt)
- return error_t, error_R
-
-
-def pose_auc(errors, thresholds):
- sort_idx = np.argsort(errors)
- errors = np.array(errors.copy())[sort_idx]
- recall = (np.arange(len(errors)) + 1) / len(errors)
- errors = np.r_[0., errors]
- recall = np.r_[0., recall]
- aucs = []
- for t in thresholds:
- last_index = np.searchsorted(errors, t)
- r = np.r_[recall[:last_index], recall[last_index-1]]
- e = np.r_[errors[:last_index], t]
- aucs.append(np.trapz(r, x=e)/t)
- return aucs
-
-
-# --- VISUALIZATION ---
-
-
-def plot_image_pair(imgs, dpi=100, size=6, pad=.5):
- n = len(imgs)
- assert n == 2, 'number of images must be two'
- figsize = (size*n, size*3/4) if size is not None else None
- _, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
- for i in range(n):
- ax[i].imshow(imgs[i], cmap=plt.get_cmap('gray'), vmin=0, vmax=255)
- ax[i].get_yaxis().set_ticks([])
- ax[i].get_xaxis().set_ticks([])
- for spine in ax[i].spines.values(): # remove frame
- spine.set_visible(False)
- plt.tight_layout(pad=pad)
-
-
-def plot_keypoints(kpts0, kpts1, color='w', ps=2):
- ax = plt.gcf().axes
- ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
- ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
-
-
-def plot_matches(kpts0, kpts1, color, lw=1.5, ps=4):
- fig = plt.gcf()
- ax = fig.axes
- fig.canvas.draw()
-
- transFigure = fig.transFigure.inverted()
- fkpts0 = transFigure.transform(ax[0].transData.transform(kpts0))
- fkpts1 = transFigure.transform(ax[1].transData.transform(kpts1))
-
- fig.lines = [matplotlib.lines.Line2D(
- (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), zorder=1,
- transform=fig.transFigure, c=color[i], linewidth=lw)
- for i in range(len(kpts0))]
- ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
- ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
-
-
-def make_matching_plot(image0, image1, kpts0, kpts1, mkpts0, mkpts1,
- color, text, path, name0, name1, show_keypoints=False,
- fast_viz=False, opencv_display=False, opencv_title='matches'):
-
- if fast_viz:
- make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0, mkpts1,
- color, text, path, show_keypoints, 10,
- opencv_display, opencv_title)
- return
-
- plot_image_pair([image0, image1])
- if show_keypoints:
- plot_keypoints(kpts0, kpts1, color='k', ps=4)
- plot_keypoints(kpts0, kpts1, color='w', ps=2)
- plot_matches(mkpts0, mkpts1, color)
-
- fig = plt.gcf()
- txt_color = 'k' if image0[:100, :150].mean() > 200 else 'w'
- fig.text(
- 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
- fontsize=15, va='top', ha='left', color=txt_color)
-
- txt_color = 'k' if image0[-100:, :150].mean() > 200 else 'w'
- fig.text(
- 0.01, 0.01, name0, transform=fig.axes[0].transAxes,
- fontsize=5, va='bottom', ha='left', color=txt_color)
-
- txt_color = 'k' if image1[-100:, :150].mean() > 200 else 'w'
- fig.text(
- 0.01, 0.01, name1, transform=fig.axes[1].transAxes,
- fontsize=5, va='bottom', ha='left', color=txt_color)
-
- plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
- plt.close()
-
-
-def make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0,
- mkpts1, color, text, path=None,
- show_keypoints=False, margin=10,
- opencv_display=False, opencv_title=''):
- H0, W0 = image0.shape
- H1, W1 = image1.shape
- H, W = max(H0, H1), W0 + W1 + margin
-
- out = 255*np.ones((H, W), np.uint8)
- out[:H0, :W0] = image0
- out[:H1, W0+margin:] = image1
- out = np.stack([out]*3, -1)
-
- if show_keypoints:
- kpts0, kpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
- white = (255, 255, 255)
- black = (0, 0, 0)
- for x, y in kpts0:
- cv2.circle(out, (x, y), 2, black, -1, lineType=cv2.LINE_AA)
- cv2.circle(out, (x, y), 1, white, -1, lineType=cv2.LINE_AA)
- for x, y in kpts1:
- cv2.circle(out, (x + margin + W0, y), 2, black, -1,
- lineType=cv2.LINE_AA)
- cv2.circle(out, (x + margin + W0, y), 1, white, -1,
- lineType=cv2.LINE_AA)
-
- mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
- color = (np.array(color[:, :3])*255).astype(int)[:, ::-1]
- for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, color):
- c = c.tolist()
- cv2.line(out, (x0, y0), (x1 + margin + W0, y1),
- color=c, thickness=1, lineType=cv2.LINE_AA)
- # display line end-points as circles
- cv2.circle(out, (x0, y0), 2, c, -1, lineType=cv2.LINE_AA)
- cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1,
- lineType=cv2.LINE_AA)
-
- Ht = int(H * 30 / 480) # text height
- txt_color_fg = (255, 255, 255)
- txt_color_bg = (0, 0, 0)
- for i, t in enumerate(text):
- cv2.putText(out, t, (10, Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
- H*1.0/480, txt_color_bg, 2, cv2.LINE_AA)
- cv2.putText(out, t, (10, Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
- H*1.0/480, txt_color_fg, 1, cv2.LINE_AA)
-
- if path is not None:
- cv2.imwrite(str(path), out)
-
- if opencv_display:
- cv2.imshow(opencv_title, out)
- cv2.waitKey(1)
-
- return out
-
-
-def error_colormap(x):
- return np.clip(
- np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)], -1), 0, 1)
diff --git a/sjlee/IMC.py b/sjlee/IMC.py
deleted file mode 100644
index d414952..0000000
--- a/sjlee/IMC.py
+++ /dev/null
@@ -1,208 +0,0 @@
-
-import os
-import sys
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import numpy as np
-from functools import partial
-
-from pydoc import source_synopsis
-from sjlee.superglue2 import SuperGlue, normalize_keypoints, arange_like, log_sinkhorn_iterations, log_optimal_transport
-
-sys.path.append(os.path.join(os.path.dirname(__file__), 'cats'))
-from cats import TransformerAggregator
-
-def dfs_freeze(model):
- for name, child in model.named_children():
- for param in child.parameters():
- param.requires_grad = False
-
- dfs_freeze(child)
-
-def softmax_with_temperature(x, beta=2, d = 1):
- r'''SFNet: Learning Object-aware Semantic Flow (Lee et al.)'''
- M, _ = x.max(dim=d, keepdim=True)
- x = x - M # subtract maximum value for stability
- exp_x = torch.exp(x/beta)
- exp_x_sum = exp_x.sum(dim=d, keepdim=True)
- return exp_x / exp_x_sum
-
-def single_optimal(scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor:
- """ Perform Differentiable Optimal Transport in Log-space for stability"""
- b, m, n = scores.shape
- one = scores.new_tensor(1)
- ms, ns = (m*one).to(scores), (n*one).to(scores)
-
- norm = - (ms + ns).log()
- log_mu = norm.expand(m)
- log_nu = norm.expand(n)
- log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)
-
- Z = log_sinkhorn_iterations(scores, log_mu, log_nu, iters)
- Z = Z - norm # multiply probabilities by M+N
- return Z
-
-# positional embedding 필요한가?
-# M * N 크기가 다 다른 문제
-class SimpleSuperCATs(SuperGlue):
- def __init__(self,
- config,
- feature_size=32,
- feature_proj_dim=128,
- depth=4,
- num_heads=4,
- mlp_ratio=4,
- ):
- super().__init__(config)
-
- # freeze superglue's layers
- dfs_freeze(self.kenc)
- dfs_freeze(self.gnn)
- dfs_freeze(self.final_proj)
-
- self.feature_size = feature_size
- self.feature_proj_dim = feature_proj_dim
- self.decoder_embed_dim = self.feature_size ** 2
-
- self.decoder = TransformerAggregator(
- img_size=self.feature_size, embed_dim=self.decoder_embed_dim, depth=depth, num_heads=num_heads,
- mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
- num_hyperpixel=1
- )
-
- self.num_heads = num_heads
- self.mask = None
-
- def forward(self, data):
- """Run SuperGlue on a pair of keypoints and descriptors"""
- desc0, desc1 = data['descriptors0'], data['descriptors1']
- kpts0, kpts1 = data['keypoints0'], data['keypoints1']
-
- desc0 = desc0.transpose(0,1)
- desc1 = desc1.transpose(0,1)
- kpts0 = torch.reshape(kpts0, (1, -1, 2))
- kpts1 = torch.reshape(kpts1, (1, -1, 2))
-
- if kpts0.shape[1] == 0 or kpts1.shape[1] == 0: # no keypoints
- shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
- return [],{
- 'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0],
- 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0],
- 'matching_scores0': kpts0.new_zeros(shape0)[0],
- 'matching_scores1': kpts1.new_zeros(shape1)[0],
- 'skip_train': True
- }
-
- # Keypoint normalization.
- kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
- kpts1 = normalize_keypoints(kpts1, data['image1'].shape)
-
- # Keypoint MLP encoder.
- desc0 = desc0 + self.kenc(kpts0, torch.transpose(data['scores0'], 0, 1))
- desc1 = desc1 + self.kenc(kpts1, torch.transpose(data['scores1'], 0, 1))
-
- # Multi-layer Transformer network.
- desc0, desc1 = self.gnn(desc0, desc1)
-
- # Final MLP projection.
- mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
-
- # Compute matching descriptor distance.
- scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
- scores = scores / self.config['descriptor_dim']**.5
-
- b, m, n = scores.shape
- max_keypoints = self.feature_size ** 2
- if m + n < max_keypoints *2:
- p2d = (0, max_keypoints-n, 0, max_keypoints-m)
- pad = scores.min().item()
- scores = F.pad(scores, p2d, 'constant', pad).type(scores.dtype)
- self.mask = (scores == pad).expand(1, self.num_heads, max_keypoints, max_keypoints)
-
- scores = self.decoder(scores[:, None, :, :], self.mask)
- scores = scores[:, :m, :n]
- scores = log_optimal_transport(
- scores, self.bin_score,
- iters=1)
-
- # Get the matches with score above "match_threshold".
- max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
- indices0, indices1 = max0.indices, max1.indices
- mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
- mutual1 = arange_like(indices1 , 1)[None] == indices0.gather(1, indices1)
- zero = scores.new_tensor(0)
- mscores0 = torch.where(mutual0, max0.values.exp(), zero)
- mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
- valid0 = mutual0 & (mscores0 > self.config['match_threshold'])
- valid1 = mutual1 & valid0.gather(1, indices1)
- indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
- indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
-
- return scores, {
- 'matches0': indices0[0], # use -1 for invalid match
- 'matches1': indices1[0], # use -1 for invalid match
- 'matching_scores0': mscores0[0],
- 'matching_scores1': mscores1[0],
- 'skip_train': False
- }
-
-
-if __name__ == '__main__':
- from superpoint import SuperPoint
-
- config = {
- 'superpoint': {
- 'nms_radius': 4,
- 'keypoint_threshold': 0.005,
- 'max_keypoints': 1024
- },
- 'superglue': {
- 'weights': 'outdoor',
- 'sinkhorn_iterations': 20,
- 'match_threshold':0.2
- }
- }
-
- """
- data = {
- 'image0': torch.randn(1, 1, 512, 512),
- 'image1': torch.randn(1, 1, 512, 512)
- }
-
- superpoint = SuperPoint(config.get('superpoint', {}))
-
- output1 = superpoint({'image': data['image0']})
- output2 = superpoint({'image': data['image1']})
-
- pred = {}
-
- pred = {**pred, **{k+'0': v for k, v in output1.items()}}
- pred = {**pred, **{k+'1': v for k, v in output2.items()}}
-
- data = {**data, **pred}
-
- for k in data:
- if isinstance(data[k], (list, tuple)):
- data[k] = torch.stack(data[k])
- """
-
- pred = {
- 'keypoints0' : torch.randn(1, 1, 484, 2),
- 'keypoints1' : torch.randn(1, 1, 484, 2),
- 'descriptors0' : torch.randn(256, 1, 484),
- 'descriptors1' : torch.randn(256, 1, 484),
- 'scores0' : torch.randn(484, 1),
- 'scores1' : torch.randn(484, 1),
- 'image0' : torch.randn(1, 1, 512, 512),
- 'image1' : torch.randn(1, 1, 512, 512),
- # 'all_matches' : torch.randn(2, 1, 1248)
- }
-
- superglue = SimpleSuperCATs(config.get('superglue', {}))
- scores, output = superglue(pred)
-
- # loss = loss_superglue(scores, pred['all_matches'].permute(1, 2, 0))
- # print(loss)
\ No newline at end of file
diff --git a/sjlee/__pycache__/IMC.cpython-38.pyc b/sjlee/__pycache__/IMC.cpython-38.pyc
deleted file mode 100644
index 0c0fbaf..0000000
Binary files a/sjlee/__pycache__/IMC.cpython-38.pyc and /dev/null differ
diff --git a/sjlee/__pycache__/IMC.cpython-39.pyc b/sjlee/__pycache__/IMC.cpython-39.pyc
deleted file mode 100644
index 98af12e..0000000
Binary files a/sjlee/__pycache__/IMC.cpython-39.pyc and /dev/null differ
diff --git a/sjlee/__pycache__/loss.cpython-38.pyc b/sjlee/__pycache__/loss.cpython-38.pyc
deleted file mode 100644
index 9a08dc8..0000000
Binary files a/sjlee/__pycache__/loss.cpython-38.pyc and /dev/null differ
diff --git a/sjlee/__pycache__/loss.cpython-39.pyc b/sjlee/__pycache__/loss.cpython-39.pyc
deleted file mode 100644
index 2a65989..0000000
Binary files a/sjlee/__pycache__/loss.cpython-39.pyc and /dev/null differ
diff --git a/sjlee/__pycache__/superglue.cpython-38.pyc b/sjlee/__pycache__/superglue.cpython-38.pyc
deleted file mode 100644
index 3acd4d8..0000000
Binary files a/sjlee/__pycache__/superglue.cpython-38.pyc and /dev/null differ
diff --git a/sjlee/__pycache__/superglue2.cpython-38.pyc b/sjlee/__pycache__/superglue2.cpython-38.pyc
deleted file mode 100644
index 79679e0..0000000
Binary files a/sjlee/__pycache__/superglue2.cpython-38.pyc and /dev/null differ
diff --git a/sjlee/__pycache__/superglue2.cpython-39.pyc b/sjlee/__pycache__/superglue2.cpython-39.pyc
deleted file mode 100644
index dc8a91a..0000000
Binary files a/sjlee/__pycache__/superglue2.cpython-39.pyc and /dev/null differ
diff --git a/sjlee/__pycache__/superpoint.cpython-38.pyc b/sjlee/__pycache__/superpoint.cpython-38.pyc
deleted file mode 100644
index 262ba3e..0000000
Binary files a/sjlee/__pycache__/superpoint.cpython-38.pyc and /dev/null differ
diff --git a/sjlee/cats/__pycache__/cats.cpython-38.pyc b/sjlee/cats/__pycache__/cats.cpython-38.pyc
deleted file mode 100644
index 3754171..0000000
Binary files a/sjlee/cats/__pycache__/cats.cpython-38.pyc and /dev/null differ
diff --git a/sjlee/cats/__pycache__/cats.cpython-39.pyc b/sjlee/cats/__pycache__/cats.cpython-39.pyc
deleted file mode 100644
index aeb5fba..0000000
Binary files a/sjlee/cats/__pycache__/cats.cpython-39.pyc and /dev/null differ
diff --git a/sjlee/cats/__pycache__/mod.cpython-38.pyc b/sjlee/cats/__pycache__/mod.cpython-38.pyc
deleted file mode 100644
index 99fe958..0000000
Binary files a/sjlee/cats/__pycache__/mod.cpython-38.pyc and /dev/null differ
diff --git a/sjlee/cats/__pycache__/mod.cpython-39.pyc b/sjlee/cats/__pycache__/mod.cpython-39.pyc
deleted file mode 100644
index eb7be3e..0000000
Binary files a/sjlee/cats/__pycache__/mod.cpython-39.pyc and /dev/null differ
diff --git a/sjlee/cats/cats.py b/sjlee/cats/cats.py
deleted file mode 100644
index ec9e200..0000000
--- a/sjlee/cats/cats.py
+++ /dev/null
@@ -1,408 +0,0 @@
-import os
-import sys
-from operator import add
-from functools import reduce, partial
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-
-import torchvision.models as models
-
-from feature_backbones import resnet
-from mod import FeatureL2Norm, unnormalise_and_convert_mapping_to_flow
-
-'''
-Modified timm library Vision Transformer implementation
-https://github.com/rwightman/pytorch-image-models
-'''
-
-# ================= timm functions START ================= #
-
-import math
-import warnings
-
-def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
- the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
- changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
- 'survival rate' as the argument.
- """
- if drop_prob == 0. or not training:
- return x
- keep_prob = 1 - drop_prob
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
- if keep_prob > 0.0 and scale_by_keep:
- random_tensor.div_(keep_prob)
- return x * random_tensor
-
-def _no_grad_trunc_normal_(tensor, mean, std, a, b):
- # Cut & paste from PyTorch official master until it's in a few official releases - RW
- # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
- def norm_cdf(x):
- # Computes standard normal cumulative distribution function
- return (1. + math.erf(x / math.sqrt(2.))) / 2.
-
- if (mean < a - 2 * std) or (mean > b + 2 * std):
- warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
- "The distribution of values may be incorrect.",
- stacklevel=2)
-
- with torch.no_grad():
- # Values are generated by using a truncated uniform distribution and
- # then using the inverse CDF for the normal distribution.
- # Get upper and lower cdf values
- l = norm_cdf((a - mean) / std)
- u = norm_cdf((b - mean) / std)
-
- # Uniformly fill tensor with values from [l, u], then translate to
- # [2l-1, 2u-1].
- tensor.uniform_(2 * l - 1, 2 * u - 1)
-
- # Use inverse cdf transform for normal distribution to get truncated
- # standard normal
- tensor.erfinv_()
-
- # Transform to proper mean, std
- tensor.mul_(std * math.sqrt(2.))
- tensor.add_(mean)
-
- # Clamp to ensure it's in the proper range
- tensor.clamp_(min=a, max=b)
- return tensor
-
-
-def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
- # type: (Tensor, float, float, float, float) -> Tensor
- r"""Fills the input Tensor with values drawn from a truncated
- normal distribution. The values are effectively drawn from the
- normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
- with values outside :math:`[a, b]` redrawn until they are within
- the bounds. The method used for generating the random values works
- best when :math:`a \leq \text{mean} \leq b`.
- Args:
- tensor: an n-dimensional `torch.Tensor`
- mean: the mean of the normal distribution
- std: the standard deviation of the normal distribution
- a: the minimum cutoff value
- b: the maximum cutoff value
- Examples:
- >>> w = torch.empty(3, 5)
- >>> nn.init.trunc_normal_(w)
- """
- return _no_grad_trunc_normal_(tensor, mean, std, a, b)
-
-# ================= timm functions END================= #
-
-
-
-
-class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-class Attention(nn.Module):
- def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
- super().__init__()
- self.num_heads = num_heads
- head_dim = dim // num_heads
- # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
- self.scale = qk_scale or head_dim ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x, mask=None):
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- attn = (q @ k.transpose(-2, -1)) * self.scale
- if mask is not None:
- attn[mask] = -1e-9
-
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-class MultiscaleBlock(nn.Module):
-
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
- super().__init__()
- self.attn = Attention(
- dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
- self.attn_multiscale = Attention(
- dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
- # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm1 = norm_layer(dim)
- self.norm2 = norm_layer(dim)
- self.norm3 = norm_layer(dim)
- self.norm4 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
- self.mlp2 = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- def forward(self, inputs):
- '''
- Multi-level aggregation
- '''
- x, mask = inputs
- B, N, H, W = x.shape
- if N == 1:
- x = x.flatten(0, 1)
- x = self.norm1(x)
- x = x + self.drop_path(self.attn(self.norm1(x), mask=mask))
- x = x + self.drop_path(self.mlp(self.norm2(x)))
- return x.view(B, N, H, W), mask
- x = x.flatten(0, 1)
- x = x + self.drop_path(self.attn(self.norm1(x)))
- x = x + self.drop_path(self.mlp2(self.norm4(x)))
- x = x.view(B, N, H, W).transpose(1, 2).flatten(0, 1)
- x = x + self.drop_path(self.attn_multiscale(self.norm3(x)))
- x = x.view(B, H, N, W).transpose(1, 2).flatten(0, 1)
- x = x + self.drop_path(self.mlp(self.norm2(x)))
- x = x.view(B, N, H, W)
- return x
-
-
-class TransformerAggregator(nn.Module):
- def __init__(self, num_hyperpixel, img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None,
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None):
- super().__init__()
- self.img_size = img_size
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
- norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
-
- self.pos_embed_x = nn.Parameter(torch.zeros(1, num_hyperpixel, 1, img_size, embed_dim // 2))
- self.pos_embed_y = nn.Parameter(torch.zeros(1, num_hyperpixel, img_size, 1, embed_dim // 2))
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
- self.blocks = nn.Sequential(*[
- MultiscaleBlock(
- dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
- for i in range(depth)])
-
- self.proj = nn.Linear(embed_dim, img_size ** 2)
- self.norm = norm_layer(embed_dim)
-
- trunc_normal_(self.pos_embed_x, std=.02)
- trunc_normal_(self.pos_embed_y, std=.02)
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- def forward(self, corr, mask=None):
- B = corr.shape[0]
- x = corr.clone()
-
- pos_embed = torch.cat((self.pos_embed_x.repeat(1, 1, self.img_size, 1, 1), self.pos_embed_y.repeat(1, 1, 1, self.img_size, 1)), dim=4)
- pos_embed = pos_embed.flatten(2, 3)
-
- x = x.transpose(-1, -2) + pos_embed
- x = self.proj(self.blocks((x, mask.transpose(-1, -2)))[0]).transpose(-1, -2) + corr # swapping the axis for swapping self-attention.
-
- x = x + pos_embed
- x = self.proj(self.blocks((x, mask))[0]) + corr
-
- return x.mean(1)
-
-
-class FeatureExtractionHyperPixel(nn.Module):
- def __init__(self, hyperpixel_ids, feature_size, freeze=True):
- super().__init__()
- self.backbone = resnet.resnet101(pretrained=True)
- self.feature_size = feature_size
- if freeze:
- for param in self.backbone.parameters():
- param.requires_grad = False
- nbottlenecks = [3, 4, 23, 3]
- self.bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), nbottlenecks)))
- self.layer_ids = reduce(add, [[i + 1] * x for i, x in enumerate(nbottlenecks)])
- self.hyperpixel_ids = hyperpixel_ids
-
-
- def forward(self, img):
- r"""Extract desired a list of intermediate features"""
-
- feats = []
-
- # Layer 0
- feat = self.backbone.conv1.forward(img)
- feat = self.backbone.bn1.forward(feat)
- feat = self.backbone.relu.forward(feat)
- feat = self.backbone.maxpool.forward(feat)
- if 0 in self.hyperpixel_ids:
- feats.append(feat.clone())
-
- # Layer 1-4
- for hid, (bid, lid) in enumerate(zip(self.bottleneck_ids, self.layer_ids)):
- res = feat
- feat = self.backbone.__getattr__('layer%d' % lid)[bid].conv1.forward(feat)
- feat = self.backbone.__getattr__('layer%d' % lid)[bid].bn1.forward(feat)
- feat = self.backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
- feat = self.backbone.__getattr__('layer%d' % lid)[bid].conv2.forward(feat)
- feat = self.backbone.__getattr__('layer%d' % lid)[bid].bn2.forward(feat)
- feat = self.backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
- feat = self.backbone.__getattr__('layer%d' % lid)[bid].conv3.forward(feat)
- feat = self.backbone.__getattr__('layer%d' % lid)[bid].bn3.forward(feat)
-
- if bid == 0:
- res = self.backbone.__getattr__('layer%d' % lid)[bid].downsample.forward(res)
-
- feat += res
-
- if hid + 1 in self.hyperpixel_ids:
- feats.append(feat.clone())
- #if hid + 1 == max(self.hyperpixel_ids):
- # break
- feat = self.backbone.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
-
- # Up-sample & concatenate features to construct a hyperimage
-
- """
- for idx, feat in enumerate(feats):
- feats[idx] = F.interpolate(feat, self.feature_size, None, 'bilinear', True)
- """
-
- return feats
-
-
-class CATs(nn.Module):
- def __init__(self,
- feature_size=16,
- feature_proj_dim=128,
- depth=4,
- num_heads=6,
- mlp_ratio=4,
- hyperpixel_ids=[0,8,20,21,26,28,29,30],
- freeze=True):
- super().__init__()
- self.feature_size = feature_size
- self.feature_proj_dim = feature_proj_dim
- self.decoder_embed_dim = self.feature_size ** 2 + self.feature_proj_dim
-
- channels = [64] + [256] * 3 + [512] * 4 + [1024] * 23 + [2048] * 3
-
- self.feature_extraction = FeatureExtractionHyperPixel(hyperpixel_ids, feature_size, freeze)
- self.proj = nn.ModuleList([
- nn.Linear(channels[i], self.feature_proj_dim) for i in hyperpixel_ids
- ])
-
- self.decoder = TransformerAggregator(
- img_size=self.feature_size, embed_dim=self.decoder_embed_dim, depth=depth, num_heads=num_heads,
- mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
- num_hyperpixel=len(hyperpixel_ids))
-
- self.l2norm = FeatureL2Norm()
-
- self.x_normal = np.linspace(-1,1,self.feature_size)
- self.x_normal = nn.Parameter(torch.tensor(self.x_normal, dtype=torch.float, requires_grad=False))
- self.y_normal = np.linspace(-1,1,self.feature_size)
- self.y_normal = nn.Parameter(torch.tensor(self.y_normal, dtype=torch.float, requires_grad=False))
-
- def softmax_with_temperature(self, x, beta, d = 1):
- r'''SFNet: Learning Object-aware Semantic Flow (Lee et al.)'''
- M, _ = x.max(dim=d, keepdim=True)
- x = x - M # subtract maximum value for stability
- exp_x = torch.exp(x/beta)
- exp_x_sum = exp_x.sum(dim=d, keepdim=True)
- return exp_x / exp_x_sum
-
- def soft_argmax(self, corr, beta=0.02):
- r'''SFNet: Learning Object-aware Semantic Flow (Lee et al.)'''
- b,_,h,w = corr.size()
-
- corr = self.softmax_with_temperature(corr, beta=beta, d=1)
- corr = corr.view(-1,h,w,h,w) # (target hxw) x (source hxw)
-
- grid_x = corr.sum(dim=1, keepdim=False) # marginalize to x-coord.
- x_normal = self.x_normal.expand(b,w)
- x_normal = x_normal.view(b,w,1,1)
- grid_x = (grid_x*x_normal).sum(dim=1, keepdim=True) # b x 1 x h x w
-
- grid_y = corr.sum(dim=2, keepdim=False) # marginalize to y-coord.
- y_normal = self.y_normal.expand(b,h)
- y_normal = y_normal.view(b,h,1,1)
- grid_y = (grid_y*y_normal).sum(dim=1, keepdim=True) # b x 1 x h x w
- return grid_x, grid_y
-
- def mutual_nn_filter(self, correlation_matrix):
- r"""Mutual nearest neighbor filtering (Rocco et al. NeurIPS'18)"""
- corr_src_max = torch.max(correlation_matrix, dim=3, keepdim=True)[0]
- corr_trg_max = torch.max(correlation_matrix, dim=2, keepdim=True)[0]
- corr_src_max[corr_src_max == 0] += 1e-30
- corr_trg_max[corr_trg_max == 0] += 1e-30
-
- corr_src = correlation_matrix / corr_src_max
- corr_trg = correlation_matrix / corr_trg_max
-
- return correlation_matrix * (corr_src * corr_trg)
-
- def corr(self, src, trg):
- return src.flatten(2).transpose(-1, -2) @ trg.flatten(2)
-
- def forward(self, target, source):
- B, _, H, W = target.size()
-
- src_feats = self.feature_extraction(source)
- tgt_feats = self.feature_extraction(target)
-
- corrs = []
- src_feats_proj = []
- tgt_feats_proj = []
- for i, (src, tgt) in enumerate(zip(src_feats, tgt_feats)):
- corr = self.corr(self.l2norm(src), self.l2norm(tgt))
- corrs.append(corr)
- src_feats_proj.append(self.proj[i](src.flatten(2).transpose(-1, -2)))
- tgt_feats_proj.append(self.proj[i](tgt.flatten(2).transpose(-1, -2)))
-
- src_feats = torch.stack(src_feats_proj, dim=1)
- tgt_feats = torch.stack(tgt_feats_proj, dim=1)
- corr = torch.stack(corrs, dim=1)
-
- corr = self.mutual_nn_filter(corr)
-
- refined_corr = self.decoder(corr, src_feats, tgt_feats)
-
- grid_x, grid_y = self.soft_argmax(refined_corr.view(B, -1, self.feature_size, self.feature_size))
-
- flow = torch.cat((grid_x, grid_y), dim=1)
- flow = unnormalise_and_convert_mapping_to_flow(flow)
-
- return flow
diff --git a/sjlee/cats/feature_backbones/__pycache__/resnet.cpython-38.pyc b/sjlee/cats/feature_backbones/__pycache__/resnet.cpython-38.pyc
deleted file mode 100644
index b023d5f..0000000
Binary files a/sjlee/cats/feature_backbones/__pycache__/resnet.cpython-38.pyc and /dev/null differ
diff --git a/sjlee/cats/feature_backbones/__pycache__/resnet.cpython-39.pyc b/sjlee/cats/feature_backbones/__pycache__/resnet.cpython-39.pyc
deleted file mode 100644
index 26d7638..0000000
Binary files a/sjlee/cats/feature_backbones/__pycache__/resnet.cpython-39.pyc and /dev/null differ
diff --git a/sjlee/cats/feature_backbones/resnet.py b/sjlee/cats/feature_backbones/resnet.py
deleted file mode 100644
index 2c94e68..0000000
--- a/sjlee/cats/feature_backbones/resnet.py
+++ /dev/null
@@ -1,342 +0,0 @@
-import torch
-import torch.nn as nn
-#from .utils import load_state_dict_from_url
-try:
- from torch.hub import load_state_dict_from_url
-except ImportError:
- from torch.utils.model_zoo import load_url as load_state_dict_from_url
-
-
-__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
- 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
- 'wide_resnet50_2', 'wide_resnet101_2']
-
-
-model_urls = {
- 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
- 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
- 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
- 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
- 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
- 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
- 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
- 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
- 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
-}
-
-
-def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
- """3x3 convolution with padding"""
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
- padding=dilation, groups=groups, bias=False, dilation=dilation)
-
-
-def conv1x1(in_planes, out_planes, stride=1):
- """1x1 convolution"""
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-
-class BasicBlock(nn.Module):
- expansion = 1
- __constants__ = ['downsample']
-
- def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
- base_width=64, dilation=1, norm_layer=None):
- super(BasicBlock, self).__init__()
- if norm_layer is None:
- norm_layer = nn.BatchNorm2d
- if groups != 1 or base_width != 64:
- raise ValueError('BasicBlock only supports groups=1 and base_width=64')
- if dilation > 1:
- raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
- # Both self.conv1 and self.downsample layers downsample the input when stride != 1
- self.conv1 = conv3x3(inplanes, planes, stride)
- self.bn1 = norm_layer(planes)
- self.relu = nn.ReLU(inplace=True)
- self.conv2 = conv3x3(planes, planes)
- self.bn2 = norm_layer(planes)
- self.downsample = downsample
- self.stride = stride
-
- def forward(self, x):
- identity = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
-
- if self.downsample is not None:
- identity = self.downsample(x)
-
- out += identity
- out = self.relu(out)
-
- return out
-
-
-class Bottleneck(nn.Module):
- expansion = 4
- __constants__ = ['downsample']
-
- def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
- base_width=64, dilation=1, norm_layer=None):
- super(Bottleneck, self).__init__()
- if norm_layer is None:
- norm_layer = nn.BatchNorm2d
- width = int(planes * (base_width / 64.)) * groups
- # Both self.conv2 and self.downsample layers downsample the input when stride != 1
- self.conv1 = conv1x1(inplanes, width)
- self.bn1 = norm_layer(width)
- self.conv2 = conv3x3(width, width, stride, groups, dilation)
- self.bn2 = norm_layer(width)
- self.conv3 = conv1x1(width, planes * self.expansion)
- self.bn3 = norm_layer(planes * self.expansion)
- self.relu = nn.ReLU(inplace=True)
- self.downsample = downsample
- self.stride = stride
-
- def forward(self, x):
- identity = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
- out = self.relu(out)
-
- out = self.conv3(out)
- out = self.bn3(out)
-
- if self.downsample is not None:
- identity = self.downsample(x)
-
- out += identity
- out = self.relu(out)
-
- return out
-
-
-class ResNet(nn.Module):
-
- def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
- groups=1, width_per_group=64, replace_stride_with_dilation=None,
- norm_layer=None):
- super(ResNet, self).__init__()
- if norm_layer is None:
- norm_layer = nn.BatchNorm2d
- self._norm_layer = norm_layer
-
- self.inplanes = 64
- self.dilation = 1
- if replace_stride_with_dilation is None:
- # each element in the tuple indicates if we should replace
- # the 2x2 stride with a dilated convolution instead
- replace_stride_with_dilation = [False, False, False]
- if len(replace_stride_with_dilation) != 3:
- raise ValueError("replace_stride_with_dilation should be None "
- "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
- self.groups = groups
- self.base_width = width_per_group
- self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
- bias=False)
- self.bn1 = norm_layer(self.inplanes)
- self.relu = nn.ReLU(inplace=True)
- self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
- self.layer1 = self._make_layer(block, 64, layers[0])
- self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
- dilate=replace_stride_with_dilation[0])
- self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
- dilate=replace_stride_with_dilation[1])
- self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
- dilate=replace_stride_with_dilation[2])
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
- self.fc = nn.Linear(512 * block.expansion, num_classes)
-
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
- # Zero-initialize the last BN in each residual branch,
- # so that the residual branch starts with zeros, and each residual block behaves like an identity.
- # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
- if zero_init_residual:
- for m in self.modules():
- if isinstance(m, Bottleneck):
- nn.init.constant_(m.bn3.weight, 0)
- elif isinstance(m, BasicBlock):
- nn.init.constant_(m.bn2.weight, 0)
-
- def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
- norm_layer = self._norm_layer
- downsample = None
- previous_dilation = self.dilation
- if dilate:
- self.dilation *= stride
- stride = 1
- if stride != 1 or self.inplanes != planes * block.expansion:
- downsample = nn.Sequential(
- conv1x1(self.inplanes, planes * block.expansion, stride),
- norm_layer(planes * block.expansion),
- )
-
- layers = []
- layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
- self.base_width, previous_dilation, norm_layer))
- self.inplanes = planes * block.expansion
- for _ in range(1, blocks):
- layers.append(block(self.inplanes, planes, groups=self.groups,
- base_width=self.base_width, dilation=self.dilation,
- norm_layer=norm_layer))
-
- return nn.Sequential(*layers)
-
- def _forward(self, x):
- x = self.conv1(x)
- print(x.shape)
- x = self.bn1(x)
- x = self.relu(x)
- x = self.maxpool(x)
-
- x = self.layer1(x)
- x = self.layer2(x)
- x = self.layer3(x)
- x = self.layer4(x)
-
- x = self.avgpool(x)
- x = torch.flatten(x, 1)
- x = self.fc(x)
-
- return x
-
- # Allow for accessing forward method in a inherited class
- forward = _forward
-
-
-def _resnet(arch, block, layers, pretrained, progress, **kwargs):
- model = ResNet(block, layers, **kwargs)
- if pretrained:
- state_dict = load_state_dict_from_url(model_urls[arch],
- progress=progress)
- model.load_state_dict(state_dict)
- return model
-
-
-def resnet18(pretrained=False, progress=True, **kwargs):
- r"""ResNet-18 model from
- `"Deep Residual Learning for Image Recognition"