Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/romav2/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class Cfg:
default_factory=lambda: [11, 17]
) # [4, 11, 17, 23] for dinov3 style
weights_path: str | None = None
module_path: str | None = None # Path to local directory containing module

def __new__(cls, cfg: Cfg) -> nn.Module:
partial_wrap = partial(
Expand All @@ -101,9 +102,10 @@ def __new__(cls, cfg: Cfg) -> nn.Module:
normalizer = imagenet
# TODO: this will break in distributed if not available locally
dinov3_vitl16: nn.Module = torch.hub.load(
repo_or_dir="facebookresearch/dinov3:adc254450203739c8149213a7a69d8d905b4fcfa",
repo_or_dir="facebookresearch/dinov3:adc254450203739c8149213a7a69d8d905b4fcfa" if cfg.module_path is None else cfg.module_path,
model="dinov3_vitl16",
pretrained=cfg.weights_path is not None,
source="github" if cfg.module_path is None else "local",
weights=cfg.weights_path,
skip_validation=True,
).to(device)
Expand All @@ -118,7 +120,9 @@ def __new__(cls, cfg: Cfg) -> nn.Module:
normalizer = imagenet

dinov2_vit14: nn.Module = torch.hub.load(
"facebookresearch/dinov2", "dinov2_vitl14"
repo_or_dir="facebookresearch/dinov2" if cfg.module_path is None else cfg.module_path,
model="dinov2_vitl14",
source="github" if cfg.module_path is None else "local"
).to(device)
dinov2_vit14.mask_token = None
layers = _get_layers(cfg.layer_idx, dinov2_vit14)
Expand Down
11 changes: 8 additions & 3 deletions src/romav2/romav2.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class Cfg:
setting: Setting = "precise"
compile: bool = True
name: str = "RoMa v2"
weights: dict = None

# settings
H_lr: int
Expand All @@ -94,10 +95,14 @@ def __init__(self, cfg: Cfg | None = None):
if cfg is None:
# default
cfg = RoMaV2.Cfg()

if cfg.weights is None:
weights = torch.hub.load_state_dict_from_url(
"https://github.com/Parskatt/RoMaV2/releases/download/weights/romav2.pt"
)
else:
weights = cfg.weights

weights = torch.hub.load_state_dict_from_url(
"https://github.com/Parskatt/RoMaV2/releases/download/weights/romav2.pt"
)
self.f = Descriptor(cfg.descriptor)
self.matcher = Matcher(cfg.matcher)
self.cfg = cfg
Expand Down