diff --git a/cotracker/models/core/cotracker/cotracker.py b/cotracker/models/core/cotracker/cotracker.py index 53178fbe..b97d66df 100644 --- a/cotracker/models/core/cotracker/cotracker.py +++ b/cotracker/models/core/cotracker/cotracker.py @@ -8,6 +8,8 @@ import torch.nn as nn import torch.nn.functional as F +from huggingface_hub import PyTorchModelHubMixin + from cotracker.models.core.model_utils import sample_features4d, sample_features5d from cotracker.models.core.embeddings import ( get_2d_embedding, @@ -26,7 +28,11 @@ torch.manual_seed(0) -class CoTracker2(nn.Module): +class CoTracker2(nn.Module, PyTorchModelHubMixin, + library_name="co-tracker", + repo_url="https://github.com/facebookresearch/co-tracker", + license="cc-by-nc-4.0", + tags=["object-tracking"]): def __init__( self, window_len=8,