From 1e35f911a201c484d646dbc4bbaaba90fe51f55c Mon Sep 17 00:00:00 2001
From: Niels
Date: Thu, 24 Apr 2025 15:07:06 +0200
Subject: [PATCH] Add hf_hub_download

---
 README.md                    | 8 +-------
 dinov2/vision_transformer.py | 9 ++++++++-
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 5025b59..c62b27c 100644
--- a/README.md
+++ b/README.md
@@ -134,13 +134,7 @@ transform = transforms.Compose([
 ])
 
 # Load model
-model = webssl_dino1b_full2b_224()
-
-# Load weights
-checkpoint_path = "path/to/downloaded/weights.pth"
-state_dict = torch.load(checkpoint_path, map_location="cpu")
-msg = model.load_state_dict(state_dict, strict=False)
-print(f"Loaded weights: {msg}")
+model = webssl_dino1b_full2b_224(pretrained=True)
 model.cuda().eval()
 
 # Process an image
diff --git a/dinov2/vision_transformer.py b/dinov2/vision_transformer.py
index 8788751..bc0371d 100644
--- a/dinov2/vision_transformer.py
+++ b/dinov2/vision_transformer.py
@@ -16,6 +16,8 @@
 import torch.utils.checkpoint
 from torch.nn.init import trunc_normal_
 
+from huggingface_hub import hf_hub_download
+
 from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
 
 
@@ -351,7 +353,7 @@ def init_weights_vit_timm(module: nn.Module, name: str = ""):
         if module.bias is not None:
             nn.init.zeros_(module.bias)
 
-def webssl_dino300m_full2b_224(img_size=224, patch_size=14, num_register_tokens=0, **kwargs):
+def webssl_dino300m_full2b_224(img_size=224, patch_size=14, num_register_tokens=0, pretrained=True, **kwargs):
     """
     Web-DINO ViT-300M
     DINOv2's "large" architecture / ViT-L
@@ -373,6 +375,11 @@ def webssl_dino300m_full2b_224(img_size=224, patch_size=14, num_register_tokens
         num_register_tokens=num_register_tokens,
         **kwargs,
     )
+    if pretrained:
+        filepath = hf_hub_download(repo_id="facebook/webssl-dino300m-full2-224-pt", filename="dinov2_vitg_300m.pth")
+        state_dict = torch.load(filepath, map_location="cpu")
+        model.load_state_dict(state_dict)
+
     return model
 
 