diff --git a/nets/net34.py b/nets/net34.py index 1e25f5d..60f97c3 100644 --- a/nets/net34.py +++ b/nets/net34.py @@ -6,7 +6,14 @@ from nets.blocks import CNBlockConfig, ConvNeXt, conv1x1, RelUpdateBlock, InputPadder, CorrBlock, BasicEncoder -class Net(nn.Module): +from huggingface_hub import PyTorchModelHubMixin + + +class Net(nn.Module, PyTorchModelHubMixin, + repo_url="https://github.com/aharley/alltracker", + paper_url="https://huggingface.co/papers/2506.07310", + tags=["tracking"], + license="mit"): def __init__( self, seqlen, diff --git a/requirements.txt b/requirements.txt index 727ddb9..058d29e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,4 @@ scikit-learn==1.5.2 scikit-image==0.24.0 tensorboardX==2.6.2.2 prettytable==3.12.0 +huggingface-hub==0.33.1