Add support for training trackastra with SAM2 features #61

Open

anwai98 wants to merge 7 commits into weigertlab:main from anwai98:add-training-support-with-sam2-feats

Conversation

Contributor

@anwai98 anwai98 commented Apr 1, 2026

Hi @C-Achard,

Here are my minimal changes to make training work with SAM2 features.

Let me know how it looks!

PS. In case it helps, here's my yaml config file to train trackastra:

```yaml
# Trackastra finetuning config file for TOIAM dataset (using SAM2 features)
# Run: python /mnt/vast-nhr/home/archit/u12090/trackastra/scripts/train.py -c train_config.yaml

name: toiam_sam2_features
outdir: ./runs

# Data
ndim: 2
input_train:
  - /mnt/vast-nhr/projects/cidas/cca/data/toiam/data/00
  - /mnt/vast-nhr/projects/cidas/cca/data/toiam/data/01
input_val:
  - /mnt/vast-nhr/projects/cidas/cca/data/toiam/data/04
detection_folders:
  - TRA
  - SEG

# Feature backbone (aligned to pretrained model)
features: pretrained_feats_aug
pretrained_feats_model: facebook/sam2.1-hiera-base-plus
pretrained_feats_mode: mean_patches_exact
pretrained_feats_additional_props: regionprops_small
pretrained_n_augs: 15
reduced_pretrained_feat_dim: 128
rotate_features: true

# Finetuning from pretrained
model: /user/archit/u12090/.local/share/trackastra/models/general_2d_w_SAM2_features

# Model architecture (matching pretrained)
d_model: 256
num_encoder_layers: 4
num_decoder_layers: 4
dropout: 0.05
window: 4
attn_dist_mode: v1
causal_norm: none

# Training hyperparameters
epochs: 500
warmup_epochs: 5
train_samples: 32000
batch_size: 16
max_tokens: 2048
weight_decay: 0.01
weight_by_dataset: true

# Augmentation
crop_size:
  - 320
  - 320

# Caching
cachedir: ./runs/.cache

# Logging and other misc. stuff
logger: tensorboard
seed: 42
```

Contributor Author

anwai98 commented Apr 1, 2026

Poof, some of the ruff linting rules are funny haha. All should be working now. Lemme know how it looks @C-Achard

Contributor

@C-Achard C-Achard left a comment


Thanks @anwai98, I had a quick look and the approach seems reasonable. If you end up requiring changes on the pretrained_feats repo, I'm happy to have a look as well.

One thing I noticed: in train.py, if no model path is given (training from scratch), it will load the basic Trackastra model rather than the one from pretrained_feats, since in the inference-only version create() is called only from TrackingTransformer.from_folder, and then it would likely crash due to the extra args.
Now, you did mention you wanted to fine-tune only, but it may be best to add some error handling in case anyone tries to train a pretrained_feats model from scratch, since without a guard the resulting exception will likely be unclear.
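A minimal sketch of the kind of guard suggested above. The function and argument names here (`check_pretrained_feats_args`, `model_path`, `features`) are illustrative, not the actual train.py variables:

```python
def check_pretrained_feats_args(model_path, features):
    """Fail early and clearly when a pretrained_feats model is trained from scratch.

    Hypothetical helper: `features` is the feature-type string from the config
    (e.g. "pretrained_feats_aug"), `model_path` the pretrained checkpoint path.
    """
    if features.startswith("pretrained_feats") and model_path is None:
        raise ValueError(
            "Training with pretrained features requires an existing model "
            "checkpoint (set `model:` in the config); training a "
            "pretrained_feats model from scratch is not supported."
        )


# Fine-tuning case: a checkpoint path is given, no error.
check_pretrained_feats_args(
    "models/general_2d_w_SAM2_features", "pretrained_feats_aug"
)
```

The point is only to replace an obscure downstream crash with an actionable message at argument-parsing time.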

Otherwise, I noticed some slightly misleading help strings in the CLI; perhaps have a look at the manuscript for better context on what these options do (I added comments on these with recommended defaults).

Finally, if your next step is to train a model, those previous configs may come in handy for that.

I hope this helps, I'm afraid I cannot test this extensively right now but happy to help further if anything is unclear in the review.

Best,
Cyril

```python
parser.add_argument(
    "--pretrained_n_augs",
    type=int,
    default=15,
```
Contributor


Note: this may be a bit high, depending on the dataset size. I'd recommend starting with much lower values and leaning on the feature disambiguation to avoid overfitting.
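For intuition on what `pretrained_n_augs` controls, the augmented-feature extraction can be thought of as aggregating backbone features over `n_augs` augmented views of each crop. A toy numpy sketch (the real pipeline's extractor, augmentations, and function names differ; everything here is a stand-in):

```python
import numpy as np

rng = np.random.default_rng(0)


def features_over_augs(extract, crop, n_augs, augment):
    """Toy version: extract features from n_augs augmented views, then average."""
    feats = [extract(augment(crop)) for _ in range(n_augs)]
    return np.mean(feats, axis=0)


# Dummy stand-ins for the SAM2 feature extractor and the augmentation pipeline.
extract = lambda x: x.mean(axis=(0, 1))            # (H, W, C) -> (C,)
augment = lambda x: x + rng.normal(0, 0.01, x.shape)

crop = rng.random((32, 32, 3))
f = features_over_augs(extract, crop, n_augs=15, augment=augment)
print(f.shape)  # (3,)
```

Each extra augmentation adds compute and, on small datasets, more ways to fit noise, hence the suggestion to start low.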

"--reduced_pretrained_feat_dim",
type=int,
default=None,
help="Reduce pretrained feature dimension via PCA to this size",
Contributor

@C-Achard C-Achard Apr 2, 2026


Since it does not look like you explicitly re-implemented the PCA dimensionality reduction I used at some point (which never made it into the final pipeline), I think this refers to the dimension of the pretrained features after a single FCL, as in https://github.com/C-Achard/trackastra/blob/a238b2cadc8e3b954c4af4afeba6df8faf18be71/trackastra/model/model.py#L296.
So this should not mention PCA, but rather the dimension of the pretrained features you'd like to feed to the encoder (which gets concatenated to the additional region props).
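The distinction can be sketched in plain numpy: a single learned linear map (a stand-in for the FCL in model.py, not the actual implementation) projects the raw SAM2 feature dimension down to `reduced_pretrained_feat_dim`, and the result is concatenated with the additional region props before entering the encoder. All dimensions below are illustrative:

```python
import numpy as np

rng = np.random.default_rng(42)

sam2_dim = 256      # raw pretrained (SAM2) feature dimension, illustrative
reduced_dim = 128   # reduced_pretrained_feat_dim from the config
n_props = 7         # size of the additional regionprops vector, illustrative
n_objects = 10

# Stand-in for the single fully connected layer (note: no PCA involved).
W = rng.normal(size=(sam2_dim, reduced_dim))
b = np.zeros(reduced_dim)

sam2_feats = rng.normal(size=(n_objects, sam2_dim))
regionprops = rng.normal(size=(n_objects, n_props))

reduced = sam2_feats @ W + b                              # (n_objects, reduced_dim)
encoder_input = np.concatenate([reduced, regionprops], axis=1)
print(encoder_input.shape)  # (10, 135)
```

So `reduced_pretrained_feat_dim` sets the width of the learned projection's output, not the target dimension of any PCA.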

anwai98 and others added 3 commits April 7, 2026 09:18
Co-authored-by: Cyril Achard <cyril.achard@epfl.ch>
Co-authored-by: Cyril Achard <cyril.achard@epfl.ch>
Co-authored-by: Cyril Achard <cyril.achard@epfl.ch>
@anwai98
Contributor Author

anwai98 commented Apr 7, 2026

Hi @C-Achard,

Thank you so much for the detailed feedback. I'll check it out later in the evening and get back to you!
