From 1e8a395bf9e901628b3b1686b660e90be19bb8c0 Mon Sep 17 00:00:00 2001 From: Fabien Servant Date: Tue, 10 Feb 2026 09:12:54 +0100 Subject: [PATCH] Using local models on filesystem --- src/romav2/features.py | 8 ++++++-- src/romav2/romav2.py | 11 ++++++++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/romav2/features.py b/src/romav2/features.py index 6836eaa..92107bf 100644 --- a/src/romav2/features.py +++ b/src/romav2/features.py @@ -88,6 +88,7 @@ class Cfg: default_factory=lambda: [11, 17] ) # [4, 11, 17, 23] for dinov3 style weights_path: str | None = None + module_path: str | None = None # Path to local directory containing module def __new__(cls, cfg: Cfg) -> nn.Module: partial_wrap = partial( @@ -101,9 +102,10 @@ def __new__(cls, cfg: Cfg) -> nn.Module: normalizer = imagenet # TODO: this will break in distributed if not available locally dinov3_vitl16: nn.Module = torch.hub.load( - repo_or_dir="facebookresearch/dinov3:adc254450203739c8149213a7a69d8d905b4fcfa", + repo_or_dir="facebookresearch/dinov3:adc254450203739c8149213a7a69d8d905b4fcfa" if cfg.module_path is None else cfg.module_path, model="dinov3_vitl16", pretrained=cfg.weights_path is not None, + source="github" if cfg.module_path is None else "local", weights=cfg.weights_path, skip_validation=True, ).to(device) @@ -118,7 +120,9 @@ def __new__(cls, cfg: Cfg) -> nn.Module: normalizer = imagenet dinov2_vit14: nn.Module = torch.hub.load( - "facebookresearch/dinov2", "dinov2_vitl14" + repo_or_dir="facebookresearch/dinov2" if cfg.module_path is None else cfg.module_path, + model="dinov2_vitl14", + source="github" if cfg.module_path is None else "local" ).to(device) dinov2_vit14.mask_token = None layers = _get_layers(cfg.layer_idx, dinov2_vit14) diff --git a/src/romav2/romav2.py b/src/romav2/romav2.py index 06aa2cc..d38628c 100644 --- a/src/romav2/romav2.py +++ b/src/romav2/romav2.py @@ -79,6 +79,7 @@ class Cfg: setting: Setting = "precise" compile: bool = True name: str = "RoMa v2" + weights: dict = None # settings H_lr: int @@ -94,10 +95,14 @@ def __init__(self, cfg: Cfg | None = None): if cfg is None: # default cfg = RoMaV2.Cfg() + + if cfg.weights is None: + weights = torch.hub.load_state_dict_from_url( + "https://github.com/Parskatt/RoMaV2/releases/download/weights/romav2.pt" + ) + else: + weights = cfg.weights - weights = torch.hub.load_state_dict_from_url( - "https://github.com/Parskatt/RoMaV2/releases/download/weights/romav2.pt" - ) self.f = Descriptor(cfg.descriptor) self.matcher = Matcher(cfg.matcher) self.cfg = cfg