diff --git a/.cursorrules b/.cursorrules
new file mode 100644
index 0000000..bbb12b8
--- /dev/null
+++ b/.cursorrules
@@ -0,0 +1,15 @@
+You are an expert coder in Python that likes concise, elegant, and self-explanatory code.
+
+Here are some general rules:
+
+
+- Follow PEP 8 naming conventions:
+ - snake_case for functions and variables
+ - PascalCase for classes
+ - UPPER_CASE for constants
+- Write clear comments explaining the rationale behind complex algorithms, but don't be overly verbose
+- Use Google-style docstrings
+- Use type hints to improve code readability and catch potential errors
+- Use f-strings for string interpolation
+- Prefer pathlib over os.path
+- Prefer tuples over lists for immutable data
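+
+For example, code following these rules might look like this (an illustrative sketch):
+
+```python
+from pathlib import Path
+
+PREVIEW_LINES = 3  # constant in UPPER_CASE
+
+
+def preview_file(path: Path, n_lines: int = PREVIEW_LINES) -> tuple[str, ...]:
+    """Return the first lines of a text file.
+
+    Args:
+        path: File to read.
+        n_lines: Number of lines to return.
+
+    Returns:
+        The first `n_lines` lines as an immutable tuple.
+    """
+    lines = tuple(path.read_text().splitlines()[:n_lines])
+    print(f"Previewing {len(lines)} lines from {path}")
+    return lines
+```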
diff --git a/.gitignore b/.gitignore
index e19e224..a9c86a8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -177,3 +177,6 @@ cython_debug/
#.idea/
trackastra/_version.py
+
+CTC_DATA/*
+.vscode/settings.json
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4902265..57732db 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,4 +8,4 @@ repos:
rev: v0.4.6
hooks:
- id: ruff
- args: [--fix, --unsafe-fixes, --preview, --verbose]
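+        # --exit-zero makes ruff report lint violations without failing the hook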
+ args: [--fix, --unsafe-fixes, --preview, --verbose, --exit-zero]
diff --git a/README.md b/README.md
index 0fdc5e9..2faf8e7 100644
--- a/README.md
+++ b/README.md
@@ -22,11 +22,13 @@ This repository contains the Python implementation of Trackastra.
Please first set up a Python environment (with Python version 3.10 or higher), preferably via [conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) or [mamba](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html#mamba-install).
+### Simple installation
Trackastra can then be installed from PyPI using `pip`:
```bash
pip install trackastra
```
+### With ILP support
For tracking with an integer linear program (ILP, which is optional):
```bash
conda create --name trackastra python=3.10 --no-default-packages
@@ -34,8 +36,21 @@ conda activate trackastra
conda install -c conda-forge -c gurobi -c funkelab ilpy
pip install "trackastra[ilp]"
```
+
+### Development installation
+
+```bash
+conda create --name trackastra python=3.10 --no-default-packages
+conda activate trackastra
+conda install -c conda-forge -c gurobi -c funkelab ilpy
+git clone https://github.com/weigertlab/trackastra.git
+pip install -e "./trackastra[ilp,dev]"
+```
-Notes:
+
+
+### Notes/Troubleshooting
+
- For the optional ILP linking, this will install [`motile`](https://funkelab.github.io/motile/index.html) and binaries for two discrete optimizers:
1. The [Gurobi Optimizer](https://www.gurobi.com/). This is a commercial solver, which requires a valid license. Academic licenses are provided for free, see [here](https://www.gurobi.com/academia/academic-program-and-licenses/) for how to obtain one.
@@ -43,6 +58,9 @@ Notes:
2. The [SCIP Optimizer](https://www.scipopt.org/), a free and open source solver. If `motile` does not find a valid Gurobi license, it will fall back to using SCIP.
- On MacOS, installing packages into the conda environment before installing `ilpy` can cause problems.
- 2024-06-07: On Apple M3 chips, you might have to use the nightly build of `torch` and `torchvision`, or worst case build them yourself.
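+
+  For example, installing the nightly wheels might look like this (a sketch; check the [PyTorch install instructions](https://pytorch.org/get-started/locally/) for the current command):
+
+  ```bash
+  pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
+  ```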
+
+
+
## Usage
@@ -113,7 +131,54 @@ v.add_labels(masks_tracked)
v.add_tracks(data=napari_tracks, graph=napari_tracks_graph)
```
-### Training a model on your own data
+
+
+
+### Fiji (via TrackMate)
+
+
+Trackastra is one of the available trackers in [TrackMate](https://imagej.net/plugins/trackmate/). For installation and usage instructions, take a look at this [tutorial](https://imagej.net/plugins/trackmate/trackers/trackmate-trackastra).
+
+
+
+### Docker images
+
+
+Some of our models are available as Docker images on [Docker Hub](https://hub.docker.com/r/bentaculum/trackastra-track/tags). Currently, we only provide CPU-based images.
+
+Track within a Docker container with the following command, filling in the `<...>` placeholders:
+
+```bash
+docker run -it -v <input_dir>:/data -v <results_dir>:/results bentaculum/trackastra-track:<tag> --input_test /data/<dataset> --detection_folder <detection_folder>
+```
+
+#### Example with a Cell Tracking Challenge model
+
+
+```bash
+wget http://data.celltrackingchallenge.net/training-datasets/Fluo-N2DH-GOWT1.zip
+unzip Fluo-N2DH-GOWT1.zip
+chmod -R 775 Fluo-N2DH-GOWT1
+docker pull bentaculum/trackastra-track:model.ctc-linking.ilp
+docker run -it -v ./:/data -v ./:/results bentaculum/trackastra-track:model.ctc-linking.ilp --input_test /data/Fluo-N2DH-GOWT1/01 --detection_folder TRA
+```
+
+
+
+
+
+### Command Line Interface
+
+After [installation](#installation), simply run in your terminal
+
+```bash
+trackastra track --help
+```
+
+to build a command for tracking directly from images and corresponding instance segmentation masks saved on disk as two series of TIF files.
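+
+For example, a possible invocation (a sketch; the exact flag names may differ, so rely on `trackastra track --help` for the authoritative options):
+
+```bash
+trackastra track -i tif_timeseries -m tif_timeseries_masks --output-ctc tracked --model-pretrained general_2d
+```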
+
+
+## Training a model on your own data
To run an example
- clone this repository and go into the scripts directory with `cd trackastra/scripts`.
diff --git a/scripts/TAP_aug_bacteria copy.yaml b/scripts/TAP_aug_bacteria copy.yaml
new file mode 100644
index 0000000..ebe15e9
--- /dev/null
+++ b/scripts/TAP_aug_bacteria copy.yaml
@@ -0,0 +1,97 @@
+# Forked from zih_bacteria.yaml
+name: vanvliet_tarrow_aug
+# model: /home/achard/trastra_v2/BEST_2025-06-04_13-23-38_vanvliet_sam21_aug
+epochs: 750
+warmup_epochs: 5
+window: 4
+attn_dist_mode: v1
+delta_cutoff: 1
+num_encoder_layers: 4
+num_decoder_layers: 4
+causal_norm: none
+d_model: 128
+dropout: 0.05
+pos_embed_per_dim: 32
+lr: 0.0001
+train_samples: 32000
+max_tokens: 2048
+batch_size: 64
+detection_folders:
+- TRA
+crop_size:
+- 320
+- 320
+attn_positional_bias: rope
+ndim: 2
+# Features config
+features: pretrained_feats_aug
+# features: pretrained_feats
+# features: wrfeat
+#### Additional parameters for pretrained_feats
+augment: 3
+pretrained_n_augs: 25
+# pretrained_feats_model: facebook/sam2.1-hiera-base-plus
+# pretrained_feats_model: facebookresearch/co-tracker
+# pretrained_feats_model: facebook/dinov2-base
+pretrained_feats_model: weigertlab/tarrow
+pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2
+# pretrained_feats_model: debug/encoded_labels
+reduced_pretrained_feat_dim: 128
+pretrained_feats_mode: mean_patches_exact
+# pretrained_feats_mode: nearest_patch
+pretrained_feats_additional_props: regionprops_small
+# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding
+rotate_features: true
+# disable_xy_coords: true
+# disable_all_coords: True
+cachedir: /backup/achard/cache
+outdir: /home/achard/trastra_v2
+# weight_decay: 0.01
+#### Logger
+logger: wandb
+wandb_project: "trackastra_v2"
+# Paths config
+cache: true
+compress: true
+distributed: false
+###
+
+input_train:
+- /home/achard/CTC_DATA/vanvliet/cib/140409-03 #
+- /home/achard/CTC_DATA/vanvliet/cib/140415-08
+- /home/achard/CTC_DATA/vanvliet/recA/151027-05
+- /home/achard/CTC_DATA/vanvliet/recA/151027-06
+# - /home/achard/CTC_DATA/vanvliet/recA/151027-10 # Has an empty frame, annoying
+- /home/achard/CTC_DATA/vanvliet/recA/151028-01
+- /home/achard/CTC_DATA/vanvliet/recA/151029-05
+- /home/achard/CTC_DATA/vanvliet/recA/151029-11
+- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-6
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-1
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-2
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-11
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E4-12
+- /home/achard/CTC_DATA/vanvliet/metA/150317-07
+- /home/achard/CTC_DATA/vanvliet/metA/150318-06
+- /home/achard/CTC_DATA/vanvliet/metA/150331-12
+- /home/achard/CTC_DATA/vanvliet/metA/151222-10
+- /home/achard/CTC_DATA/vanvliet/metA/151222-11
+- /home/achard/CTC_DATA/vanvliet/pheA/150324-03
+- /home/achard/CTC_DATA/vanvliet/pheA/150324-05
+- /home/achard/CTC_DATA/vanvliet/pheA/150325-04
+- /home/achard/CTC_DATA/vanvliet/pheA/160112-04
+- /home/achard/CTC_DATA/vanvliet/trpL/150303-01
+- /home/achard/CTC_DATA/vanvliet/trpL/150303-08
+
+input_val:
+- /home/achard/CTC_DATA/vanvliet/trpL/150428-08
+- /home/achard/CTC_DATA/vanvliet/trpL/151021-11
+
+input_test:
+# - data/CTC_DATA/vanvliet/cib/140408-01 # dense labels not precise
+- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-1 # ok
+- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-5 # ok
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-12 # ok
+- /home/achard/CTC_DATA/vanvliet/trpL/150309-04 # ok
+# - data/CTC_DATA/vanvliet/trpL/150310-05 # was empty
+- /home/achard/CTC_DATA/vanvliet/trpL/150310-11 # not in delta test set. quite some global moving!
+- /home/achard/CTC_DATA/vanvliet/pheA/160112-06 # ok. tricky. not in delta test set.
diff --git a/scripts/cotrack3_aug_bacteria.yaml b/scripts/cotrack3_aug_bacteria.yaml
new file mode 100644
index 0000000..8e8d979
--- /dev/null
+++ b/scripts/cotrack3_aug_bacteria.yaml
@@ -0,0 +1,96 @@
+# Forked from zih_bacteria.yaml
+name: vanvliet_cotracker3_aug
+# model: /home/achard/trastra_v2/BEST_2025-06-04_13-23-38_vanvliet_sam21_aug
+epochs: 750
+warmup_epochs: 5
+window: 4
+attn_dist_mode: v1
+delta_cutoff: 1
+num_encoder_layers: 4
+num_decoder_layers: 4
+causal_norm: none
+d_model: 128
+dropout: 0.05
+pos_embed_per_dim: 32
+lr: 0.0001
+train_samples: 32000
+max_tokens: 2048
+batch_size: 64
+detection_folders:
+- TRA
+crop_size:
+- 320
+- 320
+attn_positional_bias: rope
+ndim: 2
+# Features config
+features: pretrained_feats_aug
+# features: pretrained_feats
+# features: wrfeat
+#### Additional parameters for pretrained_feats
+augment: 3
+pretrained_n_augs: 25
+# pretrained_feats_model: facebook/sam2.1-hiera-base-plus
+pretrained_feats_model: facebookresearch/co-tracker
+# pretrained_feats_model: weigertlab/tarrow
+# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2
+# pretrained_feats_model: debug/encoded_labels
+reduced_pretrained_feat_dim: 128
+pretrained_feats_mode: mean_patches_exact
+# pretrained_feats_mode: nearest_patch
+pretrained_feats_additional_props: regionprops_small
+# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding
+rotate_features: true
+# disable_xy_coords: true
+# disable_all_coords: True
+cachedir: /home/achard/cache
+outdir: /home/achard/trastra_v2
+# weight_decay: 0.01
+#### Logger
+logger: wandb
+wandb_project: "trackastra_v2"
+# Paths config
+cache: true
+compress: true
+distributed: false
+###
+
+input_train:
+- /home/achard/CTC_DATA/vanvliet/cib/140409-03 #
+- /home/achard/CTC_DATA/vanvliet/cib/140415-08
+- /home/achard/CTC_DATA/vanvliet/recA/151027-05
+- /home/achard/CTC_DATA/vanvliet/recA/151027-06
+# - /home/achard/CTC_DATA/vanvliet/recA/151027-10 # Has an empty frame, annoying
+- /home/achard/CTC_DATA/vanvliet/recA/151028-01
+- /home/achard/CTC_DATA/vanvliet/recA/151029-05
+- /home/achard/CTC_DATA/vanvliet/recA/151029-11
+- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-6
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-1
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-2
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-11
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E4-12
+- /home/achard/CTC_DATA/vanvliet/metA/150317-07
+- /home/achard/CTC_DATA/vanvliet/metA/150318-06
+- /home/achard/CTC_DATA/vanvliet/metA/150331-12
+- /home/achard/CTC_DATA/vanvliet/metA/151222-10
+- /home/achard/CTC_DATA/vanvliet/metA/151222-11
+- /home/achard/CTC_DATA/vanvliet/pheA/150324-03
+- /home/achard/CTC_DATA/vanvliet/pheA/150324-05
+- /home/achard/CTC_DATA/vanvliet/pheA/150325-04
+- /home/achard/CTC_DATA/vanvliet/pheA/160112-04
+- /home/achard/CTC_DATA/vanvliet/trpL/150303-01
+- /home/achard/CTC_DATA/vanvliet/trpL/150303-08
+
+input_val:
+- /home/achard/CTC_DATA/vanvliet/trpL/150428-08
+- /home/achard/CTC_DATA/vanvliet/trpL/151021-11
+
+input_test:
+# - data/CTC_DATA/vanvliet/cib/140408-01 # dense labels not precise
+- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-1 # ok
+- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-5 # ok
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-12 # ok
+- /home/achard/CTC_DATA/vanvliet/trpL/150309-04 # ok
+# - data/CTC_DATA/vanvliet/trpL/150310-05 # was empty
+- /home/achard/CTC_DATA/vanvliet/trpL/150310-11 # not in delta test set. quite some global moving!
+- /home/achard/CTC_DATA/vanvliet/pheA/160112-06 # ok. tricky. not in delta test set.
diff --git a/scripts/deepcell/SAM_deepcell.yaml b/scripts/deepcell/SAM_deepcell.yaml
new file mode 100644
index 0000000..60d2c66
--- /dev/null
+++ b/scripts/deepcell/SAM_deepcell.yaml
@@ -0,0 +1,100 @@
+name: deepcell_sam_aug_no_drpt
+# model: /home/achard/tratra_runs_eval/2025-05-28_10-04-50_vanvliet_sam21_aug
+epochs: 1000
+warmup_epochs: 5
+window: 4
+attn_dist_mode: v1
+delta_cutoff: 1
+num_encoder_layers: 4
+num_decoder_layers: 4
+causal_norm: none
+d_model: 128
+# dropout: 0.05
+pos_embed_per_dim: 32
+lr: 0.0001
+train_samples: 32000
+max_tokens: 2048
+batch_size: 128
+detection_folders:
+- TRA
+crop_size:
+- 320
+- 320
+attn_positional_bias: rope
+ndim: 2
+# Features config
+features: pretrained_feats_aug
+# features: pretrained_feats
+# features: wrfeat
+#### Additional parameters for pretrained_feats
+augment: 3
+pretrained_n_augs: 25
+pretrained_feats_model: facebook/sam-vit-base
+reduced_pretrained_feat_dim: 128
+# pretrained_feats_model: weigertlab/tarrow
+# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2
+# pretrained_feats_model: debug/encoded_labels
+pretrained_feats_mode: mean_patches_exact
+# pretrained_feats_mode: nearest_patch
+pretrained_feats_additional_props: regionprops_small
+# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding
+# rotate_features: true
+# disable_xy_coords: true
+# disable_all_coords: True
+cachedir: /backup/achard/cache
+outdir: /home/achard/trastra_runs_v2
+#### Logger
+logger: wandb
+wandb_project: "trackastra_v2_deepcell"
+# Paths config
+cache: True
+compress: True
+distributed: True
+
+input_train:
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/00
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/01
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/02
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/03
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/04
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/05
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/06
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/07
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/08
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/09
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/10
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/11
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/12
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/13
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/14
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/15
+# - /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/16
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/16_corrected
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/17
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/18
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/19
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/20
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/23
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/25
+
+input_val:
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/21
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/22
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/24
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/44
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/45
+
+
+input_test:
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/00
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/01
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/02
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/03
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/04
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/05
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/06
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/07
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/08
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/09
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/10
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/11
diff --git a/scripts/deepcell/cotracker3_deepcell.yaml b/scripts/deepcell/cotracker3_deepcell.yaml
new file mode 100644
index 0000000..d8dfd53
--- /dev/null
+++ b/scripts/deepcell/cotracker3_deepcell.yaml
@@ -0,0 +1,100 @@
+name: deepcell_cotrack_aug_no_drpt
+# model: /home/achard/tratra_runs_eval/2025-05-28_10-04-50_vanvliet_sam21_aug
+epochs: 1000
+warmup_epochs: 5
+window: 4
+attn_dist_mode: v1
+delta_cutoff: 1
+num_encoder_layers: 4
+num_decoder_layers: 4
+causal_norm: none
+d_model: 128
+dropout: 0.05
+pos_embed_per_dim: 32
+lr: 0.0001
+train_samples: 32000
+max_tokens: 2048
+batch_size: 128
+detection_folders:
+- TRA
+crop_size:
+- 320
+- 320
+attn_positional_bias: rope
+ndim: 2
+# Features config
+features: pretrained_feats_aug
+# features: pretrained_feats
+# features: wrfeat
+#### Additional parameters for pretrained_feats
+augment: 3
+pretrained_n_augs: 25
+pretrained_feats_model: facebookresearch/co-tracker
+reduced_pretrained_feat_dim: 128
+# pretrained_feats_model: weigertlab/tarrow
+# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2
+# pretrained_feats_model: debug/encoded_labels
+pretrained_feats_mode: mean_patches_exact
+# pretrained_feats_mode: nearest_patch
+pretrained_feats_additional_props: regionprops_small
+# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding
+# rotate_features: true
+# disable_xy_coords: true
+# disable_all_coords: True
+cachedir: /backup/achard/cache
+outdir: /home/achard/trastra_runs_v2
+#### Logger
+logger: wandb
+wandb_project: "trackastra_v2_deepcell"
+# Paths config
+cache: True
+compress: True
+distributed: True
+
+input_train:
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/00
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/01
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/02
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/03
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/04
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/05
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/06
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/07
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/08
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/09
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/10
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/11
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/12
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/13
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/14
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/15
+# - /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/16
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/16_corrected
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/17
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/18
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/19
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/20
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/23
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/25
+
+input_val:
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/21
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/22
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/24
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/44
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/45
+
+
+input_test:
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/00
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/01
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/02
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/03
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/04
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/05
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/06
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/07
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/08
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/09
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/10
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/11
diff --git a/scripts/deepcell/sam21_deepcell.yaml b/scripts/deepcell/sam21_deepcell.yaml
new file mode 100644
index 0000000..c855032
--- /dev/null
+++ b/scripts/deepcell/sam21_deepcell.yaml
@@ -0,0 +1,185 @@
+name: deepcell_sam21_aug
+epochs: 1000
+warmup_epochs: 5
+window: 4
+attn_dist_mode: v1
+num_encoder_layers: 4
+num_decoder_layers: 4
+causal_norm: none
+d_model: 256
+dropout: 0.05
+pos_embed_per_dim: 32
+lr: 0.0001
+train_samples: 32000
+max_tokens: 2048
+batch_size: 128
+detection_folders:
+- TRA
+crop_size:
+- 256
+- 256
+attn_positional_bias: rope
+ndim: 2
+# Features config
+features: pretrained_feats_aug
+#### Additional parameters for pretrained_feats
+augment: 3
+pretrained_n_augs: 25
+pretrained_feats_model: facebook/sam2.1-hiera-base-plus
+reduced_pretrained_feat_dim: 128
+pretrained_feats_mode: mean_patches_exact
+pretrained_feats_additional_props: regionprops_small
+rotate_features: true
+cachedir: /backup/achard/cache
+outdir: /home/achard/trastra_v3_deepcell
+# weight_decay: 0.01
+#### Logger
+logger: wandb
+wandb_project: "trackastra_v3_deepcell"
+# Paths config
+
+### DEEPCELL SPECIFIC
+delta_cutoff: 2
+spatial_pos_cutoff: 96
+div_upweight: 2
+###
+cache: true
+compress: true
+distributed: false
+num_workers: 24
+
+input_train:
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/00
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/01
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/02
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/03
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/04
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/05
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/06
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/07
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/08
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/09
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/10
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/11
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/12
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/13
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/14
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/15
+# - /backup/achard/CTC_DATA/deepcell/deepcell_full/train/16
+- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/16_corrected
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/17
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/18
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/19
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/20
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/21
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/22
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/23
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/24
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/25
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/26
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/27
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/28
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/29
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/30
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/31
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/32
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/33
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/34
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/35
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/36
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/37
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/38
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/39
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/40
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/41
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/42
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/43
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/44
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/45
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/46
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/47
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/48
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/49
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/50
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/51
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/52
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/53
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/54
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/55
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/56
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/57
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/58
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/59
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/60
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/61
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/62
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/63
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/64
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/65
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/66
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/67
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/68
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/69
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/70
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/71
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/72
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/73
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/74
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/75
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/76
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/77
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/78
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/79
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/80
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/81
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/82
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/83
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/84
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/85
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/86
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/87
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/88
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/89
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/90
+input_val:
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/00
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/01
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/02
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/03
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/04
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/05
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/06
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/07
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/08
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/09
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/10
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/11
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/12
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/13
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/14
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/15
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/16
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/17
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/18
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/19
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/20
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/21
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/22
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/23
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/24
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/25
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/26
+
+input_test:
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/00
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/01
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/02
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/03
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/04
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/05
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/06
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/07
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/08
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/09
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/10
+- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/11
\ No newline at end of file
diff --git a/scripts/dinov2_aug_bacteria.yaml b/scripts/dinov2_aug_bacteria.yaml
new file mode 100644
index 0000000..150434d
--- /dev/null
+++ b/scripts/dinov2_aug_bacteria.yaml
@@ -0,0 +1,97 @@
+# Forked from zih_bacteria.yaml
+name: vanvliet_dinov2_aug
+# model: /home/achard/trastra_v2/BEST_2025-06-04_13-23-38_vanvliet_sam21_aug
+epochs: 750
+warmup_epochs: 5
+window: 4
+attn_dist_mode: v1
+delta_cutoff: 1
+num_encoder_layers: 4
+num_decoder_layers: 4
+causal_norm: none
+d_model: 128
+dropout: 0.05
+pos_embed_per_dim: 32
+lr: 0.0001
+train_samples: 32000
+max_tokens: 2048
+batch_size: 64
+detection_folders:
+- TRA
+crop_size:
+- 320
+- 320
+attn_positional_bias: rope
+ndim: 2
+# Features config
+features: pretrained_feats_aug
+# features: pretrained_feats
+# features: wrfeat
+#### Additional parameters for pretrained_feats
+augment: 3
+pretrained_n_augs: 25
+# pretrained_feats_model: facebook/sam2.1-hiera-base-plus
+# pretrained_feats_model: facebookresearch/co-tracker
+pretrained_feats_model: facebook/dinov2-base
+# pretrained_feats_model: weigertlab/tarrow
+# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2
+# pretrained_feats_model: debug/encoded_labels
+reduced_pretrained_feat_dim: 128
+pretrained_feats_mode: mean_patches_exact
+# pretrained_feats_mode: nearest_patch
+pretrained_feats_additional_props: regionprops_small
+# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding
+rotate_features: true
+# disable_xy_coords: true
+# disable_all_coords: True
+cachedir: /backup/achard/cache
+outdir: /home/achard/trastra_v2
+# weight_decay: 0.01
+#### Logger
+logger: wandb
+wandb_project: "trackastra_v2"
+# Paths config
+cache: true
+compress: true
+distributed: false
+###
+
+input_train:
+- /home/achard/CTC_DATA/vanvliet/cib/140409-03 #
+- /home/achard/CTC_DATA/vanvliet/cib/140415-08
+- /home/achard/CTC_DATA/vanvliet/recA/151027-05
+- /home/achard/CTC_DATA/vanvliet/recA/151027-06
+# - /home/achard/CTC_DATA/vanvliet/recA/151027-10 # Has an empty frame, annoying
+- /home/achard/CTC_DATA/vanvliet/recA/151028-01
+- /home/achard/CTC_DATA/vanvliet/recA/151029-05
+- /home/achard/CTC_DATA/vanvliet/recA/151029-11
+- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-6
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-1
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-2
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-11
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E4-12
+- /home/achard/CTC_DATA/vanvliet/metA/150317-07
+- /home/achard/CTC_DATA/vanvliet/metA/150318-06
+- /home/achard/CTC_DATA/vanvliet/metA/150331-12
+- /home/achard/CTC_DATA/vanvliet/metA/151222-10
+- /home/achard/CTC_DATA/vanvliet/metA/151222-11
+- /home/achard/CTC_DATA/vanvliet/pheA/150324-03
+- /home/achard/CTC_DATA/vanvliet/pheA/150324-05
+- /home/achard/CTC_DATA/vanvliet/pheA/150325-04
+- /home/achard/CTC_DATA/vanvliet/pheA/160112-04
+- /home/achard/CTC_DATA/vanvliet/trpL/150303-01
+- /home/achard/CTC_DATA/vanvliet/trpL/150303-08
+
+input_val:
+- /home/achard/CTC_DATA/vanvliet/trpL/150428-08
+- /home/achard/CTC_DATA/vanvliet/trpL/151021-11
+
+input_test:
+# - data/CTC_DATA/vanvliet/cib/140408-01 # dense labels not precise
+- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-1 # ok
+- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-5 # ok
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-12 # ok
+- /home/achard/CTC_DATA/vanvliet/trpL/150309-04 # ok
+# - data/CTC_DATA/vanvliet/trpL/150310-05 # was empty
+- /home/achard/CTC_DATA/vanvliet/trpL/150310-11 # not in delta test set. quite some global moving!
+- /home/achard/CTC_DATA/vanvliet/pheA/160112-06 # ok. tricky. not in delta test set.
diff --git a/scripts/example_config.yaml b/scripts/example_config.yaml
index 966480f..547491b 100644
--- a/scripts/example_config.yaml
+++ b/scripts/example_config.yaml
@@ -1,20 +1,97 @@
-batch_size: 8
-crop_size:
-- 256
-- 256
+# Forked from zih_bacteria.yaml
+name: vanvliet_sam21_quiet_softmax_aug
+# model: /home/achard/trastra_v2/BEST_2025-06-04_13-23-38_vanvliet_sam21_aug
+epochs: 1000
+warmup_epochs: 5
+window: 4
+attn_dist_mode: v1
+delta_cutoff: 1
+num_encoder_layers: 4
+num_decoder_layers: 4
+# causal_norm: none
+causal_norm: quiet_softmax
+d_model: 128
+dropout: 0.0
+pos_embed_per_dim: 32
+lr: 0.0001
+train_samples: 32000
+max_tokens: 2048
+batch_size: 60
detection_folders:
- TRA
-dropout: 0.01
-example_images: False # Slow
+crop_size:
+- 320
+- 320
+attn_positional_bias: rope
+ndim: 2
+# Features config
+features: pretrained_feats_aug
+# features: pretrained_feats
+# features: wrfeat
+#### Additional parameters for pretrained_feats
+augment: 3
+pretrained_n_augs: 25
+pretrained_feats_model: facebook/sam2.1-hiera-base-plus
+# pretrained_feats_model: facebookresearch/co-tracker
+# pretrained_feats_model: weigertlab/tarrow
+# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2
+# pretrained_feats_model: debug/encoded_labels
+reduced_pretrained_feat_dim: 128
+pretrained_feats_mode: mean_patches_exact
+# pretrained_feats_mode: nearest_patch
+pretrained_feats_additional_props: regionprops_small
+# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding
+rotate_features: true
+# disable_xy_coords: true
+# disable_all_coords: True
+cachedir: /home/achard/cache
+outdir: /home/achard/trastra_v2
+# weight_decay: 0.01
+#### Logger
+logger: wandb
+wandb_project: "trackastra_v2"
+# Paths config
+cache: true
+compress: true
+distributed: false
+###
+
input_train:
-- data/ctc/Fluo-N2DL-HeLa/02
+- /backup/achard/CTC_DATA/vanvliet/cib/140409-03 #
+- /backup/achard/CTC_DATA/vanvliet/cib/140415-08
+- /backup/achard/CTC_DATA/vanvliet/recA/151027-05
+- /backup/achard/CTC_DATA/vanvliet/recA/151027-06
+# - /backup/achard/CTC_DATA/vanvliet/recA/151027-10 # Has an empty frame, annoying
+- /backup/achard/CTC_DATA/vanvliet/recA/151028-01
+- /backup/achard/CTC_DATA/vanvliet/recA/151029-05
+- /backup/achard/CTC_DATA/vanvliet/recA/151029-11
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-6
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E2-1
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E2-2
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E3-11
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E4-12
+- /backup/achard/CTC_DATA/vanvliet/metA/150317-07
+- /backup/achard/CTC_DATA/vanvliet/metA/150318-06
+- /backup/achard/CTC_DATA/vanvliet/metA/150331-12
+- /backup/achard/CTC_DATA/vanvliet/metA/151222-10
+- /backup/achard/CTC_DATA/vanvliet/metA/151222-11
+- /backup/achard/CTC_DATA/vanvliet/pheA/150324-03
+- /backup/achard/CTC_DATA/vanvliet/pheA/150324-05
+- /backup/achard/CTC_DATA/vanvliet/pheA/150325-04
+- /backup/achard/CTC_DATA/vanvliet/pheA/160112-04
+- /backup/achard/CTC_DATA/vanvliet/trpL/150303-01
+- /backup/achard/CTC_DATA/vanvliet/trpL/150303-08
+
input_val:
-- data/ctc/Fluo-N2DL-HeLa/01
-max_tokens: 2048
-name: example
-ndim: 2
-num_decoder_layers: 5
-num_encoder_layers: 5
-outdir: runs
-distributed: False
-window: 6
+- /backup/achard/CTC_DATA/vanvliet/trpL/150428-08
+- /backup/achard/CTC_DATA/vanvliet/trpL/151021-11
+
+input_test:
+# - data/CTC_DATA/vanvliet/cib/140408-01 # dense labels not precise
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-1 # ok
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-5 # ok
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E3-12 # ok
+- /backup/achard/CTC_DATA/vanvliet/trpL/150309-04 # ok
+# - data/CTC_DATA/vanvliet/trpL/150310-05 # was empty
+- /backup/achard/CTC_DATA/vanvliet/trpL/150310-11 # not in delta test set. quite some global moving!
+- /backup/achard/CTC_DATA/vanvliet/pheA/160112-06 # ok. tricky. not in delta test set.
diff --git a/scripts/sam21_aug_bacteria.yaml b/scripts/sam21_aug_bacteria.yaml
new file mode 100644
index 0000000..9128751
--- /dev/null
+++ b/scripts/sam21_aug_bacteria.yaml
@@ -0,0 +1,96 @@
+# Forked from zih_bacteria.yaml
+name: vanvliet_sam21_aug
+# model: /home/achard/trastra_v2/BEST_2025-06-04_13-23-38_vanvliet_sam21_aug
+epochs: 1000
+warmup_epochs: 5
+window: 4
+attn_dist_mode: v1
+delta_cutoff: 1
+num_encoder_layers: 4
+num_decoder_layers: 4
+causal_norm: none
+d_model: 256
+dropout: 0.05
+pos_embed_per_dim: 32
+lr: 0.0001
+train_samples: 32000
+max_tokens: 2048
+batch_size: 55
+detection_folders:
+- TRA
+crop_size:
+- 320
+- 320
+attn_positional_bias: rope
+ndim: 2
+# Features config
+features: pretrained_feats_aug
+# features: pretrained_feats
+# features: wrfeat
+#### Additional parameters for pretrained_feats
+augment: 3
+pretrained_n_augs: 25
+pretrained_feats_model: facebook/sam2.1-hiera-base-plus
+# pretrained_feats_model: facebookresearch/co-tracker
+# pretrained_feats_model: weigertlab/tarrow
+# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2
+# pretrained_feats_model: debug/encoded_labels
+reduced_pretrained_feat_dim: 128
+pretrained_feats_mode: mean_patches_exact
+# pretrained_feats_mode: nearest_patch
+pretrained_feats_additional_props: regionprops_small
+# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding
+rotate_features: true
+# disable_xy_coords: true
+# disable_all_coords: True
+cachedir: /backup/achard/cache
+outdir: /home/achard/trastra_v2
+# weight_decay: 0.01
+#### Logger
+logger: wandb
+wandb_project: "trackastra_v2"
+# Paths config
+cache: true
+compress: true
+distributed: false
+###
+
+input_train:
+- /backup/achard/CTC_DATA/vanvliet/cib/140409-03 #
+- /backup/achard/CTC_DATA/vanvliet/cib/140415-08
+- /backup/achard/CTC_DATA/vanvliet/recA/151027-05
+- /backup/achard/CTC_DATA/vanvliet/recA/151027-06
+# - /backup/achard/CTC_DATA/vanvliet/recA/151027-10 # Has an empty frame, annoying
+- /backup/achard/CTC_DATA/vanvliet/recA/151028-01
+- /backup/achard/CTC_DATA/vanvliet/recA/151029-05
+- /backup/achard/CTC_DATA/vanvliet/recA/151029-11
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-6
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E2-1
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E2-2
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E3-11
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E4-12
+- /backup/achard/CTC_DATA/vanvliet/metA/150317-07
+- /backup/achard/CTC_DATA/vanvliet/metA/150318-06
+- /backup/achard/CTC_DATA/vanvliet/metA/150331-12
+- /backup/achard/CTC_DATA/vanvliet/metA/151222-10
+- /backup/achard/CTC_DATA/vanvliet/metA/151222-11
+- /backup/achard/CTC_DATA/vanvliet/pheA/150324-03
+- /backup/achard/CTC_DATA/vanvliet/pheA/150324-05
+- /backup/achard/CTC_DATA/vanvliet/pheA/150325-04
+- /backup/achard/CTC_DATA/vanvliet/pheA/160112-04
+- /backup/achard/CTC_DATA/vanvliet/trpL/150303-01
+- /backup/achard/CTC_DATA/vanvliet/trpL/150303-08
+
+input_val:
+- /backup/achard/CTC_DATA/vanvliet/trpL/150428-08
+- /backup/achard/CTC_DATA/vanvliet/trpL/151021-11
+
+input_test:
+# - data/CTC_DATA/vanvliet/cib/140408-01 # dense labels not precise
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-1 # ok
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-5 # ok
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E3-12 # ok
+- /backup/achard/CTC_DATA/vanvliet/trpL/150309-04 # ok
+# - data/CTC_DATA/vanvliet/trpL/150310-05 # was empty
+- /backup/achard/CTC_DATA/vanvliet/trpL/150310-11 # not in delta test set. quite some global moving!
+- /backup/achard/CTC_DATA/vanvliet/pheA/160112-06 # ok. tricky. not in delta test set.
diff --git a/scripts/sam_aug_bacteria.yaml b/scripts/sam_aug_bacteria.yaml
new file mode 100644
index 0000000..17d68b7
--- /dev/null
+++ b/scripts/sam_aug_bacteria.yaml
@@ -0,0 +1,96 @@
+# Forked from zih_bacteria.yaml
+name: vanvliet_SAM_aug
+# model: /home/achard/trastra_v2/BEST_2025-06-04_13-23-38_vanvliet_sam21_aug
+epochs: 750
+warmup_epochs: 5
+window: 4
+attn_dist_mode: v1
+delta_cutoff: 1
+num_encoder_layers: 4
+num_decoder_layers: 4
+causal_norm: none
+d_model: 128
+dropout: 0.05
+pos_embed_per_dim: 32
+lr: 0.0001
+train_samples: 32000
+max_tokens: 2048
+batch_size: 64
+detection_folders:
+- TRA
+crop_size:
+- 320
+- 320
+attn_positional_bias: rope
+ndim: 2
+# Features config
+features: pretrained_feats_aug
+# features: pretrained_feats
+# features: wrfeat
+#### Additional parameters for pretrained_feats
+augment: 3
+pretrained_n_augs: 25
+pretrained_feats_model: facebook/sam-vit-base
+# pretrained_feats_model: facebookresearch/co-tracker
+# pretrained_feats_model: weigertlab/tarrow
+# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2
+# pretrained_feats_model: debug/encoded_labels
+reduced_pretrained_feat_dim: 128
+pretrained_feats_mode: mean_patches_exact
+# pretrained_feats_mode: nearest_patch
+pretrained_feats_additional_props: regionprops_small
+# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding
+rotate_features: true
+# disable_xy_coords: true
+# disable_all_coords: True
+cachedir: /home/achard/cache
+outdir: /home/achard/trastra_v2
+# weight_decay: 0.01
+#### Logger
+logger: wandb
+wandb_project: "trackastra_v2"
+# Paths config
+cache: true
+compress: true
+distributed: false
+###
+
+input_train:
+- /home/achard/CTC_DATA/vanvliet/cib/140409-03 #
+- /home/achard/CTC_DATA/vanvliet/cib/140415-08
+- /home/achard/CTC_DATA/vanvliet/recA/151027-05
+- /home/achard/CTC_DATA/vanvliet/recA/151027-06
+# - /home/achard/CTC_DATA/vanvliet/recA/151027-10 # Has an empty frame, annoying
+- /home/achard/CTC_DATA/vanvliet/recA/151028-01
+- /home/achard/CTC_DATA/vanvliet/recA/151029-05
+- /home/achard/CTC_DATA/vanvliet/recA/151029-11
+- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-6
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-1
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-2
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-11
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E4-12
+- /home/achard/CTC_DATA/vanvliet/metA/150317-07
+- /home/achard/CTC_DATA/vanvliet/metA/150318-06
+- /home/achard/CTC_DATA/vanvliet/metA/150331-12
+- /home/achard/CTC_DATA/vanvliet/metA/151222-10
+- /home/achard/CTC_DATA/vanvliet/metA/151222-11
+- /home/achard/CTC_DATA/vanvliet/pheA/150324-03
+- /home/achard/CTC_DATA/vanvliet/pheA/150324-05
+- /home/achard/CTC_DATA/vanvliet/pheA/150325-04
+- /home/achard/CTC_DATA/vanvliet/pheA/160112-04
+- /home/achard/CTC_DATA/vanvliet/trpL/150303-01
+- /home/achard/CTC_DATA/vanvliet/trpL/150303-08
+
+input_val:
+- /home/achard/CTC_DATA/vanvliet/trpL/150428-08
+- /home/achard/CTC_DATA/vanvliet/trpL/151021-11
+
+input_test:
+# - data/CTC_DATA/vanvliet/cib/140408-01 # dense labels not precise
+- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-1 # ok
+- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-5 # ok
+- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-12 # ok
+- /home/achard/CTC_DATA/vanvliet/trpL/150309-04 # ok
+# - data/CTC_DATA/vanvliet/trpL/150310-05 # was empty
+- /home/achard/CTC_DATA/vanvliet/trpL/150310-11 # not in delta test set. quite some global moving!
+- /home/achard/CTC_DATA/vanvliet/pheA/160112-06 # ok. tricky. not in delta test set.
diff --git a/scripts/train.py b/scripts/train.py
index 2b74698..a9f2de6 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -16,6 +16,7 @@
import configargparse
import git
+import humanize
import lightning as pl
import numpy as np
import psutil
@@ -25,7 +26,6 @@
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from lightning.pytorch.profilers import PyTorchProfiler
from lightning.pytorch.utilities.rank_zero import rank_zero_only
-from numerize import numerize
from skimage.morphology import binary_dilation, disk
from torch.optim.lr_scheduler import LRScheduler
from torchvision.utils import make_grid
@@ -36,6 +36,12 @@
CTCData,
collate_sequence_padding,
)
+from trackastra.data.pretrained_features import (
+ AVAILABLE_PRETRAINED_BACKBONES,
+ PretrainedFeatsExtractionMode,
+ PretrainedFeatureExtractorConfig,
+)
+from trackastra.data.wrfeat import _PROPERTIES, DEFAULT_PROPERTIES, WRFeatures
from trackastra.model import TrackingTransformer
from trackastra.utils import (
blockwise_causal_norm,
@@ -48,12 +54,13 @@
str2bool,
)
-logging.basicConfig(level=logging.INFO)
+# logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
warnings.simplefilter(action="ignore", category=FutureWarning)
-device = torch.device("cuda" if torch.cuda.is_available() else "mps")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
np.seterr(all="ignore")
@@ -83,7 +90,6 @@ def __init__(
max_epochs,
cosine_final: float = 0.001,
last_epoch=-1,
- verbose=False,
):
"""Use cosine_final to switch on/off the cosine annealing.
@@ -93,7 +99,7 @@ def __init__(
self.warmup_epochs = warmup_epochs
self.max_epochs = max_epochs
self.cosine_final = cosine_final
- super().__init__(optimizer, last_epoch, verbose)
+ super().__init__(optimizer, last_epoch)
def get_lr(self):
if not self._get_lr_called_within_step:
@@ -182,6 +188,8 @@ def __init__(
tracking_frequency: int = -1, # log TRA metrics every that epochs
batch_val_tb_idx: int = 0, # the batch index to visualize in tensorboard
div_upweight: float = 20,
+ # per_param_clipping: bool = False,
+ weight_decay: float = 0.01,
):
super().__init__()
@@ -196,30 +204,60 @@ def __init__(
self.batch_val_tb = None
self.lr = learning_rate
+ self.weight_decay = weight_decay
self.tracking_frequency = tracking_frequency
self.warmup_epochs = warmup_epochs
self.max_epochs = max_epochs
self.div_upweight = div_upweight
-
+ # self.per_param_clipping = per_param_clipping
+
def _common_step(self, batch, eps=torch.finfo(torch.float32).eps):
+ # torch.autograd.set_detect_anomaly(True)
feats = batch["features"]
+ try:
+ pretrained_feats = batch["pretrained_features"]
+ except KeyError:
+ pretrained_feats = None
coords = batch["coords"]
A = batch["assoc_matrix"]
timepoints = batch["timepoints"]
padding_mask = batch["padding_mask"]
padding_mask = padding_mask.bool()
+
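+        # Fail fast: surface NaNs in the inputs before they propagate through the model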
+ if feats is not None:
+ if torch.any(torch.isnan(feats)):
+ nan_dims = torch.any(torch.isnan(feats), dim=-1)
+                raise ValueError(f"NaN in features in dimensions: {nan_dims}")
+ if pretrained_feats is not None:
+ if torch.any(torch.isnan(pretrained_feats)):
+ nan_dims = torch.any(torch.isnan(pretrained_feats), dim=-1)
+                raise ValueError(f"NaN in pretrained features in dimensions: {nan_dims}")
+ if torch.any(torch.isnan(coords)):
+ raise ValueError("NaN in coords")
+
+ A_pred = self.model(coords, feats, pretrained_feats, padding_mask=padding_mask)
+
+ if self.model.norms: # if dict is not empty, log each entry to wandb
+ for key, value in self.model.norms.items():
+ # check wandb runner is initialized
+ self.log_dict(
+ {f"norms/{key}": value}, on_step=True, on_epoch=False, sync_dist=True
+ )
- A_pred = self.model(coords, feats, padding_mask=padding_mask)
# remove inf values that might happen due to float16 numerics
A_pred.clamp_(torch.finfo(torch.float16).min, torch.finfo(torch.float16).max)
+ # above call might interfere with backward as it is an inplace operation
+ # A_pred = A_pred.clamp(torch.finfo(torch.float16).min, torch.finfo(torch.float16).max)
mask_invalid = torch.logical_or(
padding_mask.unsqueeze(1), padding_mask.unsqueeze(2)
)
A_pred[mask_invalid] = 0
+ # above call might interfere with backward as it is an inplace operation in "linear" causal norm
+ # A_pred = A_pred.masked_fill(mask_invalid, 0)
loss = self.criterion(A_pred, A)
-
+
if self.causal_norm != "none":
# TODO speedup: I could softmax only the part of the matrix (upper triangular) that is not masked out
A_pred_soft = torch.stack(
@@ -230,10 +268,11 @@ def _common_step(self, batch, eps=torch.finfo(torch.float32).eps):
for _A, _t, _m in zip(A_pred, timepoints, mask_invalid)
]
)
- with torch.cuda.amp.autocast(enabled=False):
+            with torch.amp.autocast(enabled=False, device_type=self.device.type):
if len(A) > 0:
# debug
if torch.any(torch.isnan(A_pred_soft)):
+
print(A_pred)
print(
"AAAA pred",
@@ -252,9 +291,17 @@ def _common_step(self, batch, eps=torch.finfo(torch.float32).eps):
timepoints=timepoints.detach().cpu().numpy(),
)
+ if A_pred_soft.dtype != A.dtype:
+ logger.warning(
+ "A_pred_soft has different dtype than A, casting to A.dtype"
+ )
+ A_pred_soft = A_pred_soft.to(A.dtype)
# Keep the non-softmaxed loss for numerical stability
loss = 0.01 * loss + self.criterion_softmax(A_pred_soft, A)
+ if torch.any(torch.isnan(loss)):
+ raise ValueError("NaN after loss summing")
+
# Reweighting does not need gradients
with torch.no_grad():
block_sum1 = torch.stack(
@@ -287,6 +334,9 @@ def _common_step(self, batch, eps=torch.finfo(torch.float32).eps):
mask.sum(dim=(1, 2), keepdim=True) + eps
)
loss_per_sample = loss_normalized.sum(dim=(1, 2))
+
+ if torch.any(torch.isnan(loss_per_sample)):
+ raise ValueError("NaN in loss_per_sample after reduction")
# Hack: weight larger samples a little more...
prefactor = torch.pow(mask.sum(dim=(1, 2)), 0.2)
@@ -313,7 +363,7 @@ def checkpoint_path(self, logdir):
return None
def configure_optimizers(self):
- optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-5)
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
return dict(
optimizer=optimizer,
lr_scheduler=WarmupCosineLRScheduler(
@@ -321,12 +371,26 @@ def configure_optimizers(self):
),
)
+ def on_before_optimizer_step(self, optimizer):
+ # self.trainer.precision_plugin.scaler.unscale_(optimizer)
+ # from torch.nn.utils import clip_grad_norm_
+ # if self.per_param_clipping:
+ # for param in self.model.parameters():
+ # if param.grad is not None:
+ # clip_grad_norm_(param, max_norm=1.0)
+ # Compute the 2-norm for each layer
+ # If using mixed precision, the gradients are already unscaled here
+ from lightning.pytorch.utilities import grad_norm
+ norms = grad_norm(self.model, norm_type=2)
+ self.log_dict(norms)
+
def training_step(self, batch, batch_idx):
out = self._common_step(batch)
loss = out["loss"]
if torch.isnan(loss):
- print("NaN loss, skipping")
- return None
+ # print("NaN loss, skipping")
+ # return None
+ raise ValueError("NaN loss")
self.log(
"train_loss",
@@ -335,6 +399,7 @@ def training_step(self, batch, batch_idx):
on_step=False,
on_epoch=True,
sync_dist=True,
+ batch_size=batch["coords"].shape[0],
)
# self.train_loss.append(loss)
@@ -368,7 +433,7 @@ def validation_step(self, batch, batch_idx):
loss = out["loss"]
if torch.isnan(loss):
print("NaN loss, skipping")
- return None
+ raise ValueError("NaN loss")
self.log(
"val_loss",
@@ -377,6 +442,7 @@ def validation_step(self, batch, batch_idx):
on_step=False,
on_epoch=True,
sync_dist=True,
+ batch_size=batch["coords"].shape[0],
)
# self.val_loss.append(loss)
@@ -727,6 +793,12 @@ def _init_wandb(project, name, config):
_ = wandb.init(project=project, name=name, config=config)
+@rank_zero_only
+def create_wandb_logger(run_name, wandb_project):
+ wandb_logger = WandbLogger(name=run_name, project=wandb_project)
+ return wandb_logger
+
+
def train(args):
args.seed = seed(args.seed)
if args.model is None:
@@ -766,6 +838,8 @@ def train(args):
train_logger = TensorBoardLogger(logdir, name="tb")
elif args.logger == "wandb":
train_logger = WandbLogger(name=run_name, project=args.wandb_project)
+ # train_logger = create_wandb_logger(run_name, args.wandb_project)
+ # train_logger.log_hyperparams(training_args)
# init here to get an alert on job failure even before training
_init_wandb(project=args.wandb_project, name=run_name, config=vars(args))
@@ -782,6 +856,42 @@ def train(args):
raise ValueError(
f'Logdir {logdir} exists, set "--resume t" if you want to overwrite'
)
+
+ pretrained_config = None
+    if args.features in ("pretrained_feats", "pretrained_feats_aug"):
+ if args.pretrained_feats_model is None:
+ raise ValueError(
+                "Pretrained model must be defined if pretrained features are in use. "
+ f"Available models: {AVAILABLE_PRETRAINED_BACKBONES.keys()}"
+ )
+ if args.pretrained_feats_model not in AVAILABLE_PRETRAINED_BACKBONES:
+ raise ValueError(
+ f"Unknown pretrained model {args.pretrained_feats_model}, available: {AVAILABLE_PRETRAINED_BACKBONES.keys()}"
+ )
+ if args.pretrained_feats_mode is None:
+ raise ValueError(
+ "Pretrained mode must be defined if pretrained features are in use."
+ )
+ if args.features == "pretrained_feats_aug" and args.pretrained_n_augs is None:
+ raise ValueError(
+                "Number of augmented copies must be defined if using augmented pretrained features."
+ )
+ emb_save_path = None if args.cachedir is None else Path(args.cachedir).resolve()
+        if emb_save_path is not None and not emb_save_path.exists():
+ emb_save_path.mkdir(parents=False, exist_ok=True)
+ # pca_save_path = (
+ # Path(logdir) / "pca" if args.pretrained_feats_pca_ncomp else None
+ # )
+
+ pretrained_config = PretrainedFeatureExtractorConfig(
+ model_name=args.pretrained_feats_model,
+ mode=args.pretrained_feats_mode,
+ save_path=emb_save_path,
+ additional_features=args.pretrained_feats_additional_props,
+ model_path=args.pretrained_model_path,
+ # pca_components=args.pretrained_feats_pca_ncomp,
+ # pca_preprocessor_path=pca_save_path,
+ )
n_gpus = torch.cuda.device_count() if args.distributed else 1
if args.preallocate:
@@ -801,10 +911,15 @@ def train(args):
sanity_dist=args.sanity_dist,
crop_size=args.crop_size,
compress=args.compress,
+ pretrained_backbone_config=pretrained_config,
+ pretrained_n_augmentations=args.pretrained_n_augs,
+ rotate_features=args.rotate_features,
)
dummy_model = TrackingTransformer(
coord_dim=dummy_data.ndim,
feat_dim=dummy_data.feat_dim,
+ pretrained_feat_dim=dummy_data.pretrained_feat_dim,
+ reduced_pretrained_feat_dim=args.reduced_pretrained_feat_dim,
d_model=args.d_model,
pos_embed_per_dim=args.pos_embed_per_dim,
feat_embed_per_dim=args.feat_embed_per_dim,
@@ -817,6 +932,8 @@ def train(args):
attn_positional_bias_n_spatial=args.attn_positional_bias_n_spatial,
attn_dist_mode=args.attn_dist_mode,
causal_norm=args.causal_norm,
+ disable_xy_coords=args.disable_xy_coords,
+ disable_all_coords=args.disable_all_coords,
)
dummy_model_lightning = WrappedLightningModule(
@@ -829,6 +946,8 @@ def train(args):
tracking_frequency=args.tracking_frequency,
batch_val_tb_idx=0,
div_upweight=args.div_upweight,
+ # per_param_clipping=args.clip_grad_per_param,
+ weight_decay=args.weight_decay,
)
dummy_model_lightning.to(device)
preallocate_memory(
@@ -852,7 +971,7 @@ def train(args):
if args.only_prechecks:
return locals()
-
+
dataset_kwargs = dict(
ndim=args.ndim,
detection_folders=args.detection_folders,
@@ -864,6 +983,9 @@ def train(args):
sanity_dist=args.sanity_dist,
crop_size=args.crop_size,
compress=args.compress,
+ pretrained_backbone_config=pretrained_config,
+ pretrained_n_augmentations=args.pretrained_n_augs,
+ rotate_features=args.rotate_features,
)
sampler_kwargs = dict(
batch_size=args.batch_size,
@@ -879,12 +1001,12 @@ def train(args):
pin_memory=True,
collate_fn=collate_sequence_padding,
)
-
+
# Sampler gets wrapped with distributed sampler, which cannot sample with replacement
datamodule = BalancedDataModule(
input_train=args.input_train,
input_val=args.input_val,
- cachedir=args.cachedir,
+ cachedir=args.cachedir if args.cache else None,
augment=args.augment,
distributed=args.distributed,
dataset_kwargs=dataset_kwargs,
@@ -905,15 +1027,15 @@ def train(args):
)
callbacks.append(pl.pytorch.callbacks.Timer(interval="epoch"))
- # Mostly for stopping broken runs
- callbacks.append(
- pl.pytorch.callbacks.EarlyStopping(
- monitor="val_loss",
- patience=args.epochs // 6,
- mode="min",
- verbose=True,
- )
- )
+ # # Mostly for stopping broken runs
+ # callbacks.append(
+ # pl.pytorch.callbacks.EarlyStopping(
+ # monitor="val_loss",
+ # patience=args.epochs // 6,
+ # mode="min",
+ # verbose=True,
+ # )
+ # )
if args.example_images:
callbacks.append(ExampleImages())
@@ -932,13 +1054,23 @@ def train(args):
else:
model = TrackingTransformer.from_folder(fpath, args=args)
else:
- feat_dim = 0 if args.features == "none" else 7 if args.ndim == 2 else 12
+        if args.features in ("pretrained_feats", "pretrained_feats_aug"):  # TODO find a way to truly automate this
+ feat_dim = pretrained_config.additional_feat_dim
+ elif args.features == "wrfeat":
+ feat_dim = WRFeatures.PROPERTIES_DIMS[DEFAULT_PROPERTIES][args.ndim]
+ else:
+ feat_dim = CTCData.get_feat_dim(args.features, args.ndim)
+
+ pretrained_feat_dim = 0 if pretrained_config is None else pretrained_config.feat_dim
+
model = TrackingTransformer(
# coord_dim=datasets["train"].datasets[0].ndim,
coord_dim=args.ndim,
# feat_dim=datasets["train"].datasets[0].feat_dim,
- # FIXME hardcoded feat_dim
feat_dim=feat_dim,
+ pretrained_feat_dim=pretrained_feat_dim,
+ reduced_pretrained_feat_dim=args.reduced_pretrained_feat_dim,
d_model=args.d_model,
pos_embed_per_dim=args.pos_embed_per_dim,
feat_embed_per_dim=args.feat_embed_per_dim,
@@ -951,6 +1083,8 @@ def train(args):
attn_positional_bias_n_spatial=args.attn_positional_bias_n_spatial,
attn_dist_mode=args.attn_dist_mode,
causal_norm=args.causal_norm,
+ disable_xy_coords=args.disable_xy_coords,
+ disable_all_coords=args.disable_all_coords,
)
model_lightning = WrappedLightningModule(
@@ -963,6 +1097,8 @@ def train(args):
tracking_frequency=args.tracking_frequency,
batch_val_tb_idx=batch_val_tb_idx,
div_upweight=args.div_upweight,
+ # per_param_clipping=args.clip_grad_per_param,
+ weight_decay=args.weight_decay,
)
# Compiling does not work!
# model_lightning = torch.compile(model_lightning)
@@ -983,13 +1119,14 @@ def train(args):
tracking_frequency=args.tracking_frequency,
batch_val_tb_idx=batch_val_tb_idx,
div_upweight=args.div_upweight,
+ weight_decay=args.weight_decay,
)
else:
logging.warning(f"No checkpoint found in {logdir}")
model_lightning.to(device)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
- logging.info(f"Model has {numerize.numerize(num_params)} parameters")
+ logging.info(f"Model has {humanize.intword(num_params)} parameters")
if args.distributed:
strategy = "ddp"
@@ -1003,16 +1140,24 @@ def train(args):
else:
profiler = None
+ import platform
+
+ from lightning.pytorch.strategies import DDPStrategy
+
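+    # NCCL (the default DDP backend) is not available on Windows, so fall back to gloo.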
+ if platform.system() == "Windows":
+ strategy = DDPStrategy(process_group_backend="gloo")
+
trainer = pl.Trainer(
- accelerator="cuda",
+ accelerator="gpu" if torch.cuda.is_available() else "cpu",
strategy=strategy,
- devices=n_gpus,
+ devices=n_gpus if torch.cuda.is_available() else 1,
precision="16-mixed" if args.mixedp else 32,
logger=train_logger,
num_nodes=1,
max_epochs=args.epochs,
callbacks=callbacks,
profiler=profiler,
+ gradient_clip_val=1.0,
)
t = default_timer()
@@ -1030,7 +1175,7 @@ def train(args):
print(f"Time elapsed: {(default_timer() - t) / 60:.02f} min")
print(f"CPU Memory used: {(_process_memory() - memory) / 1e9:.2f} GB")
print(f"GPU Memory used : {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
-
+
return locals()
@@ -1045,7 +1190,8 @@ def parse_train_args():
"--config",
is_config_file=True,
help="config file path",
- default="configs/vanvliet.yaml",
+ # default="configs/vanvliet.yaml",
+        default=str(Path("scripts/example_config.yaml").resolve()),
)
parser.add_argument("-o", "--outdir", type=str, default="runs")
parser.add_argument("--name", type=str, help="Name to append to timestamp")
@@ -1105,14 +1251,7 @@ def parse_train_args():
parser.add_argument(
"--features",
type=str,
- choices=[
- "none",
- "regionprops",
- "regionprops2",
- "patch",
- "patch_regionprops",
- "wrfeat",
- ],
+ choices=list(CTCData.VALID_FEATURES),
default="wrfeat",
)
parser.add_argument(
@@ -1196,6 +1335,76 @@ def parse_train_args():
" imbalance)"
),
)
+ # Pretrained feats + extra arguments
+ parser.add_argument(
+ "--pretrained_feats_model",
+ type=str,
+ choices=list(AVAILABLE_PRETRAINED_BACKBONES.keys()),
+ default=None,
+        help="If --features is pretrained_feats, specify the model to use for feature extraction",
+ )
+ parser.add_argument(
+ "--pretrained_model_path",
+ type=str,
+ default=None,
+ help="Path to pretrained model to use for feature extraction. Only valid if features is pretrained_feats.",
+ )
+ parser.add_argument(
+ "--pretrained_feats_mode",
+ type=str,
+ # choices=["nearest_patch", "mean_patches_bbox", "max_patches_bbox", "mean_patches_exact", "max_patches_exact"],
+ choices=list(PretrainedFeatsExtractionMode.__args__),
+ default=None,
+        help="If --features is pretrained_feats, specify the mode to use for feature extraction",
+ )
+ parser.add_argument(
+ "--pretrained_feats_additional_props",
+ type=str,
+ choices=list(_PROPERTIES.keys()),
+ default=None,
+ help="Additional regionprops features to use in addition to pretrained model embeddings",
+ )
+ # parser.add_argument(
+ # "--pretrained_feats_pca_ncomp",
+ # type=int,
+ # default=None,
+ # help="Number of components to use for PCA dimensionality reduction. If None, no PCA is applied.",
+ # )
+ parser.add_argument(
+ "--weight_decay",
+ type=float,
+ default=0.01,
+ help="Weight decay for the AdamW optimizer",
+ )
+ parser.add_argument(
+ "--pretrained_n_augs",
+ type=int,
+ default=None,
+ help="Number of augmentations to use for pretrained features. Only valid if features is pretrained_feats_aug",
+ )
+ parser.add_argument(
+ "--disable_xy_coords",
+ type=str2bool,
+ default=False,
+ help="Disable x and y coordinates as input features. --features cannot be none if True.",
+ )
+ parser.add_argument(
+ "--disable_all_coords",
+ type=str2bool,
+ default=False,
+ help="Disable all coordinates T(Z)XY as input features. --features cannot be none if True.",
+ )
+ parser.add_argument(
+ "--rotate_features",
+ type=str2bool,
+ default=False,
+ help="Rotate features using augmented coordinates. features must be 'pretrained_feats' or 'pretrained_feats_aug' if True.",
+ )
+ parser.add_argument(
+ "--reduced_pretrained_feat_dim",
+ type=int,
+ default=128,
+ )
args, unknown_args = parser.parse_known_args()
diff --git a/scripts/wrfeat_aug_bacteria.yaml b/scripts/wrfeat_aug_bacteria.yaml
new file mode 100644
index 0000000..5a01feb
--- /dev/null
+++ b/scripts/wrfeat_aug_bacteria.yaml
@@ -0,0 +1,96 @@
+# Forked from zih_bacteria.yaml
+name: vanvliet_wrfeat
+# model: /home/achard/trastra_v2/BEST_2025-06-04_13-23-38_vanvliet_sam21_aug
+epochs: 800
+warmup_epochs: 5
+window: 4
+attn_dist_mode: v1
+delta_cutoff: 1
+num_encoder_layers: 4
+num_decoder_layers: 4
+causal_norm: none
+d_model: 256
+# dropout: 0.05
+pos_embed_per_dim: 32
+lr: 0.0001
+train_samples: 32000
+max_tokens: 2048
+batch_size: 64
+detection_folders:
+- TRA
+crop_size:
+- 320
+- 320
+attn_positional_bias: rope
+ndim: 2
+# Features config
+# features: pretrained_feats_aug
+# features: pretrained_feats
+features: wrfeat
+#### Additional parameters for pretrained_feats
+augment: 3
+# pretrained_n_augs: 25
+# pretrained_feats_model: facebook/sam2.1-hiera-base-plus
+# pretrained_feats_model: facebookresearch/co-tracker
+# pretrained_feats_model: weigertlab/tarrow
+# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2
+# pretrained_feats_model: debug/encoded_labels
+# reduced_pretrained_feat_dim: 128
+# pretrained_feats_mode: mean_patches_exact
+# pretrained_feats_mode: nearest_patch
+# pretrained_feats_additional_props: regionprops_small
+# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding
+# rotate_features: true
+# disable_xy_coords: true
+# disable_all_coords: True
+cachedir: /backup/achard/cache
+outdir: /home/achard/trastra_v2
+# weight_decay: 0.01
+#### Logger
+logger: wandb
+wandb_project: "trackastra_v2"
+# Paths config
+cache: true
+compress: true
+distributed: false
+###
+
+input_train:
+- /backup/achard/CTC_DATA/vanvliet/cib/140409-03 #
+- /backup/achard/CTC_DATA/vanvliet/cib/140415-08
+- /backup/achard/CTC_DATA/vanvliet/recA/151027-05
+- /backup/achard/CTC_DATA/vanvliet/recA/151027-06
+# - /backup/achard/CTC_DATA/vanvliet/recA/151027-10 # Has an empty frame, annoying
+- /backup/achard/CTC_DATA/vanvliet/recA/151028-01
+- /backup/achard/CTC_DATA/vanvliet/recA/151029-05
+- /backup/achard/CTC_DATA/vanvliet/recA/151029-11
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-6
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E2-1
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E2-2
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E3-11
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E4-12
+- /backup/achard/CTC_DATA/vanvliet/metA/150317-07
+- /backup/achard/CTC_DATA/vanvliet/metA/150318-06
+- /backup/achard/CTC_DATA/vanvliet/metA/150331-12
+- /backup/achard/CTC_DATA/vanvliet/metA/151222-10
+- /backup/achard/CTC_DATA/vanvliet/metA/151222-11
+- /backup/achard/CTC_DATA/vanvliet/pheA/150324-03
+- /backup/achard/CTC_DATA/vanvliet/pheA/150324-05
+- /backup/achard/CTC_DATA/vanvliet/pheA/150325-04
+- /backup/achard/CTC_DATA/vanvliet/pheA/160112-04
+- /backup/achard/CTC_DATA/vanvliet/trpL/150303-01
+- /backup/achard/CTC_DATA/vanvliet/trpL/150303-08
+
+input_val:
+- /backup/achard/CTC_DATA/vanvliet/trpL/150428-08
+- /backup/achard/CTC_DATA/vanvliet/trpL/151021-11
+
+input_test:
+# - data/CTC_DATA/vanvliet/cib/140408-01 # dense labels not precise
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-1 # ok
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-5 # ok
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E3-12 # ok
+- /backup/achard/CTC_DATA/vanvliet/trpL/150309-04 # ok
+# - data/CTC_DATA/vanvliet/trpL/150310-05 # was empty
+- /backup/achard/CTC_DATA/vanvliet/trpL/150310-11 # not in delta test set. quite some global moving!
+- /backup/achard/CTC_DATA/vanvliet/pheA/160112-06 # ok. tricky. not in delta test set.
diff --git a/setup.cfg b/setup.cfg
index 8afd34c..e4b3d2b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -19,11 +19,13 @@ classifiers =
[options]
packages = find:
install_requires =
- numpy<2
+ numpy
matplotlib
scipy
+ h5py
pandas
- numerize
+ dask
+ humanize
configargparse
tensorboard
pyyaml
@@ -34,6 +36,7 @@ install_requires =
kornia>=0.7.0 # TODO remove old augs
torch
torchvision
+ # transformers
lz4
imagecodecs>=2023.7.10
wandb
@@ -55,6 +58,12 @@ dev =
build
test =
pytest
+pretrained_feats =
+ transformers
+ sam-2 @ git+https://github.com/facebookresearch/sam2.git
+ segment-anything @ git+https://github.com/facebookresearch/segment-anything.git
+ scikit-learn
+ zarr
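+# install the pretrained-features extras with: pip install "trackastra[pretrained_feats]"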
[options.entry_points]
console_scripts =
diff --git a/tests/test_attn.py b/tests/test_attn.py
new file mode 100644
index 0000000..ef09f3e
--- /dev/null
+++ b/tests/test_attn.py
@@ -0,0 +1,38 @@
+import torch
+from trackastra.model.model_parts import RelativePositionalAttention
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+if __name__ == "__main__":
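+    # Sanity check: the output for sample 0 must be identical whether it is
+    # computed inside a padded batch of three or on its own (padding must not leak).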
+    model = RelativePositionalAttention(
+        coord_dim=2,
+        embed_dim=64,
+        n_head=2,
+        mode="rope",
+        attn_dist_mode="v2",
+    )
+
+    model.eval()
+    model.to(device)
+
+    B, N = 3, 11
+    q = torch.rand(B, N, 64).to(device)
+    x = torch.rand(B, N, 3).to(device)
+
+    pad_mask = torch.zeros((B, N), dtype=torch.bool).to(device)
+    pad_mask[0, -2:] = True
+    pad_mask[1, -3:] = True
+    pad_mask[2, -4:] = True
+
+    u = model(q, q, q, coords=x, padding_mask=pad_mask)
+    mask = model.attn_mask
+
+    u1 = model(q[:1], q[:1], q[:1], coords=x[:1], padding_mask=pad_mask[:1])
+    mask1 = model.attn_mask
+
+    err = torch.abs(u[:1] - u1).mean()
+    print(f"Error: {err:.4f}")
+    print("close: ", torch.allclose(u[:1], u1, rtol=1e-3, atol=1e-6))
\ No newline at end of file
diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py
index 153c485..89078fb 100644
--- a/tests/test_augmentations.py
+++ b/tests/test_augmentations.py
@@ -1,7 +1,7 @@
import numpy as np
import pytest
from scipy.ndimage import maximum_filter
-from trackastra.data import AugmentationPipeline
+from trackastra.data import AugmentationPipeline, CTCData
def plot_augs(b1, b2):
@@ -54,6 +54,8 @@ def test_augpipeline(plot=False):
if __name__ == "__main__":
test_augpipeline(plot=True)
# pipe = RandomCrop((30, 40, 10, 20), ensure_inside_points=True)
diff --git a/trackastra/data/__init__.py b/trackastra/data/__init__.py
index fbddf06..be4c89c 100644
--- a/trackastra/data/__init__.py
+++ b/trackastra/data/__init__.py
@@ -14,5 +14,31 @@
BalancedDistributedSampler,
)
from .example_data import example_data_bacteria, example_data_fluo_3d, example_data_hela
-from .utils import filter_track_df, load_tiff_timeseries, load_tracklet_links
+from .pretrained_augmentations import (
+ PretrainedAugmentations,
+ PretrainedIntensityAugmentations,
+ PretrainedMovementAugmentations,
+)
+from .pretrained_features import (
+ CellposeSAMFeatures,
+ CoTrackerFeatures,
+ DinoV2Features,
+ FeatureExtractor,
+ FeatureExtractorAugWrapper,
+ HieraFeatures,
+ MicroSAMFeatures,
+ PretrainedBackboneType,
+ PretrainedFeatsExtractionMode,
+ PretrainedFeatureExtractorConfig,
+ SAM2Features,
+ SAM2HighresFeatures,
+ SAMFeatures,
+ TAPFeatures,
+)
+from .utils import (
+ filter_track_df,
+ load_tiff_timeseries,
+ load_tracklet_links,
+ make_hashable,
+)
from .wrfeat import WRFeatures, build_windows, get_features
diff --git a/trackastra/data/data.py b/trackastra/data/data.py
index 9df93fa..b1fde4c 100644
--- a/trackastra/data/data.py
+++ b/trackastra/data/data.py
@@ -1,8 +1,13 @@
+from __future__ import annotations
+
+import hashlib
+import json
import logging
from collections.abc import Sequence
+from functools import lru_cache
from pathlib import Path
from timeit import default_timer
-from typing import Literal
+from typing import TYPE_CHECKING, ClassVar, Literal
import joblib
import lz4.frame
@@ -11,6 +16,7 @@
import pandas as pd
import tifffile
import torch
+import zarr
from numba import njit
from scipy import ndimage as ndi
from scipy.spatial.distance import cdist
@@ -32,13 +38,22 @@
extract_features_regionprops,
)
from trackastra.data.matching import matching
+from trackastra.data.pretrained_augmentations import PretrainedAugmentations
+from trackastra.utils import blockwise_sum, masks2properties, normalize
-# from ..utils import blockwise_sum, normalize
-from trackastra.utils import blockwise_sum, normalize
+from .utils import make_hashable
logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
+# logger.setLevel(logging.INFO)
+logger.setLevel(logging.DEBUG) # FIXME go back to INFO for release
+
+if TYPE_CHECKING:
+ from trackastra.data.pretrained_features import (
+ PretrainedBackboneType,
+ PretrainedFeatsExtractionMode,
+ PretrainedFeatureExtractorConfig,
+ )
+
def _filter_track_df(df, start_frame, end_frame, downscale):
"""Only keep tracklets that are present in the given time interval."""
@@ -115,6 +130,53 @@ def wrapper(*args, **kwargs):
class CTCData(Dataset):
+ """Cell Tracking Challenge data loader."""
+    # Number of features per mode per dimension
+ FEATURES_DIMENSIONS: ClassVar = {
+ "regionprops": {
+ 2: 7,
+ 3: 11,
+ },
+ "regionprops2": {
+ 2: 6,
+ 3: 11,
+ },
+ "regionprops_full": {
+ 2: 9,
+ 3: 13,
+ },
+ "patch": {
+ 2: 256,
+ 3: 256,
+ },
+ "patch_regionprops": {
+ 2: 256 + 5,
+ 3: 256 + 8,
+ },
+ "none": {
+ 2: 0,
+ 3: 0,
+        },
+        # "wrfeat" -> defined by WRFeatures.PROPERTIES_DIMS
+        # "pretrained_feats" -> defined by PretrainedFeatureExtractorConfig.feat_dim
+ }
+ VALID_FEATURES: ClassVar = {
+ "none",
+ "regionprops",
+ "regionprops2",
+ "patch",
+ "patch_regionprops",
+ "wrfeat",
+ "pretrained_feats",
+ "pretrained_feats_aug",
+ }
+
+ def __new__(cls, *args, **kwargs):
+ # Check if features is "pretrained_feats_aug"; if it is, use CTCDataAugPretrainedFeats class
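+        # e.g. CTCData(features="pretrained_feats_aug", ...) transparently returns
+        # a CTCDataAugPretrainedFeats instance.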
+ if kwargs.get("features") == "pretrained_feats_aug":
+ return super().__new__(globals()["CTCDataAugPretrainedFeats"])
+ return super().__new__(cls)
+
def __init__(
self,
root: str = "",
@@ -134,40 +196,59 @@ def __init__(
"patch",
"patch_regionprops",
"wrfeat",
+ "pretrained_feats",
+ "pretrained_feats_aug",
] = "wrfeat",
sanity_dist: bool = False,
crop_size: tuple | None = None,
return_dense: bool = False,
compress: bool = False,
+ pretrained_backbone_config: PretrainedFeatureExtractorConfig | None = None,
+ # pca_preprocessor: EmbeddingsPCACompression | None = None,
+ rotate_features: bool = False,
+ load_immediately: bool = True,
**kwargs,
) -> None:
- """_summary_.
-
- Args:
- root (str):
- Folder containing the CTC TRA folder.
- ndim (int):
- Number of dimensions of the data. Defaults to 2d
- (if ndim=3 and data is two dimensional, it will be cast to 3D)
- detection_folders:
- List of relative paths to folder with detections.
- Defaults to ["TRA"], which uses the ground truth detections.
- window_size (int):
- Window size for transformer.
- slice_pct (tuple):
- Slice the dataset by percentages (from, to).
- augment (int):
- if 0, no data augmentation. if > 0, defines level of data augmentation.
- features (str):
- Types of features to use.
- sanity_dist (bool):
- Use euclidian distance instead of the association matrix as a target.
- crop_size (tuple):
- Size of the crops to use for augmentation. If None, no cropping is used.
- return_dense (bool):
- Return dense masks and images in the data samples.
- compress (bool):
- Compress elements/remove img if not needed to save memory for large datasets
+        """Initializes a CTCData dataset.
+
+        Args:
+ root (str):
+ Folder containing the CTC TRA folder.
+ ndim (int):
+ Number of dimensions of the data. Defaults to 2d
+ (if ndim=3 and data is two dimensional, it will be cast to 3D)
+ detection_folders:
+ List of relative paths to folder with detections.
+ Defaults to ["TRA"], which uses the ground truth detections.
+ window_size (int):
+ Window size for transformer.
+ slice_pct (tuple):
+ Slice the dataset by percentages (from, to).
+ augment (int):
+ if 0, no data augmentation. if > 0, defines level of data augmentation.
+ features (str):
+ Types of features to use.
+ sanity_dist (bool):
+                Use euclidean distance instead of the association matrix as a target.
+ crop_size (tuple):
+ Size of the crops to use for augmentation. If None, no cropping is used.
+ return_dense (bool):
+ Return dense masks and images in the data samples.
+ compress (bool):
+ Compress elements/remove img if not needed to save memory for large datasets
+ pretrained_backbone_config (PretrainedFeatureExtractorConfig):
+ Configuration for the pretrained backbone.
+ If mode is set to "pretrained_feats", this configuration is used to extract features.
+ Ignored otherwise.
+ rotate_features (bool):
+ Apply rotation to features based on (augmented) coordinates.
+ Only valid if used with "pretrained_feats" or "pretrained_feats_aug" mode.
+ load_immediately (bool):
+ If True, load the data immediately. If False, load the data lazily.
+ If False, you need to call `start_loading()` to load the data.
+ Defaults to True.
+ # pca_preprocessor (EmbeddingsPCACompression):
+ # PCA preprocessor for the pretrained features.
+ # If mode is set to "pretrained_feats", this is used to reduce the dimensionality of the features.
+ # Ignored otherwise.
"""
super().__init__()
@@ -180,14 +261,29 @@ def __init__(
self.downscale_spatial = downscale_spatial
self.downscale_temporal = downscale_temporal
self.detection_folders = detection_folders
- self.ndim = ndim
+ self._ndim = ndim
self.features = features
+ self.rotate_feats = rotate_features
- if features not in ("none", "wrfeat") and features not in _PROPERTIES[ndim]:
- raise ValueError(
- f"'{features}' not one of the supported {ndim}D features"
- f" {tuple(_PROPERTIES[ndim].keys())}"
- )
+        if features not in self.VALID_FEATURES and features not in _PROPERTIES[self._ndim]:
+            raise ValueError(
+                f"'{features}' not one of the supported {self._ndim}D features"
+                f" {tuple(_PROPERTIES[self._ndim].keys())}"
+            )
+
+        if features in ("pretrained_feats", "pretrained_feats_aug"):
+            try:
+                import transformers  # noqa: F401
+            except ImportError as e:
+                msg = (
+                    "Please install the pretrained_feats extra requirements to use"
+                    " pretrained features mode.\nRun:\n"
+                    '    pip install "trackastra[pretrained_feats]"\n'
+                    "to install the required dependencies."
+                )
+                raise ImportError(msg) from e
logger.info(f"ROOT (config): \t{self.root}")
self.root, self.gt_tra_folder = self._guess_root_and_gt_tra_folder(self.root)
@@ -209,9 +305,18 @@ def __init__(
self.img_folder = self._guess_img_folder(self.root)
logger.info(f"IMG (guessed):\t{self.img_folder}")
- self.feat_dim, self.augmenter, self.cropper = self._setup_features_augs(
- ndim, features, augment, crop_size
- )
+ self._pretrained_config = None
+        if features in ("pretrained_feats", "pretrained_feats_aug"):
+ if pretrained_backbone_config is None:
+ raise ValueError("Pretrained backbone config must be provided for pretrained features mode.")
+ self.pretrained_config = pretrained_backbone_config
+ if self.pretrained_config.save_path is None:
+ self.pretrained_config.save_path = self.img_folder
+ self.FEATURES_DIMENSIONS["pretrained_feats"] = self.pretrained_config.feat_dim
+
+ self.augment_level = augment
+ self.crop_size = crop_size
+ self.augmenter, self.cropper = self._setup_features_augs()
if window_size <= 1:
raise ValueError("window must be >1")
@@ -224,35 +329,139 @@ def __init__(
self.compress = compress
self.start_frame = 0
self.end_frame = None
-
+
+ # Pretrained model attributes for feature extraction if specified
+ self._pretrained_model_input_size_factor = 1
+ self.feature_extractor_input_size = None
+ self.feature_extractor_save_path = None
+ self.pretrained_features = None
+ self.feature_extractor = None
+ self.pretrained_feature_augmenter = None
+ # self.pca_preprocessor = pca_preprocessor
+
+ if load_immediately:
+ self.start_loading()
+
+ if kwargs:
+ logger.warning(f"Unused kwargs: {kwargs}")
+
+ def start_loading(self):
start = default_timer()
- if self.features == "wrfeat":
- self.windows = self._load_wrfeat()
- else:
- self.windows = self._load()
+ self._init_features() # loads and creates windows
self.n_divs = self._get_ndivs(self.windows)
- if len(self.windows) > 0:
- self.ndim = self.windows[0]["coords"].shape[1]
- self.n_objects = tuple(len(t["coords"]) for t in self.windows)
- logger.info(
- f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track"
- f" windows from {self.root} ({default_timer() - start:.1f}s)\n"
- )
- else:
- self.n_objects = 0
- logger.warning(f"Could not load any tracks from {self.root}")
+ self._get_ndim_and_nobj(start)
if self.compress:
self._compress_data()
- # def from_ctc
-
+ def _init_features(self):
+        if self.features in ("wrfeat", "pretrained_feats"):
+ self.windows = self._load_wrfeat()
+ else:
+ self.windows = self._load()
+
+ @property
+ def config(self):
+ return {
+ "root": str(self.root),
+ "ndim": self.ndim,
+ "use_gt": self.use_gt,
+ "detection_folders": self.detection_folders,
+ "window_size": self.window_size,
+ "max_tokens": self.max_tokens,
+ "slice_pct": self.slice_pct,
+ "downscale_spatial": self.downscale_spatial,
+ "downscale_temporal": self.downscale_temporal,
+ "augment": self.augment_level,
+ "features": self.features,
+ "sanity_dist": self.sanity_dist,
+ "crop_size": self.crop_size,
+ "return_dense": self.return_dense,
+ "compress": self.compress,
+ "pretrained_config": (
+ self.pretrained_config.to_dict() if self.pretrained_config else None
+ ),
+ "rotate_features": self.rotate_feats,
+ }
+
+ @property
+ def config_hash(self):
+ """Returns a hash of the configuration."""
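+        # Stable across runs; used e.g. to key on-disk window caches such as
+        # windows/{config_hash}.zarr (see CTCDataAugPretrainedFeats).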
+ cfg = make_hashable(self.config)
+ config_str = json.dumps(cfg, sort_keys=True)
+ return hashlib.sha256(config_str.encode()).hexdigest()
+
+ @property
+ def ndim(self):
+ return self._ndim
+
+ @ndim.setter
+ def ndim(self, value: int):
+ if value not in (2, 3):
+ raise ValueError(f"ndim must be 2 or 3, got {value}")
+ self._ndim = value
+
+ @property
+ def feat_dim(self):
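+        """Dimension of the per-object handcrafted feature vector.
+
+        With a pretrained backbone this counts only the additional regionprops
+        features; the embedding size itself is exposed via `pretrained_feat_dim`.
+        """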
+        if self.features == "wrfeat":
+            return wrfeat.WRFeatures.PROPERTIES_DIMS[
+                wrfeat.DEFAULT_PROPERTIES
+            ][self.ndim]
+        if self.pretrained_config is None:
+            return self.FEATURES_DIMENSIONS[self.features][self.ndim]
+        return self.pretrained_config.additional_feat_dim
+
+ @property
+ def pretrained_config(self):
+ return self._pretrained_config
+
+ @property
+ def pretrained_feat_dim(self):
+ if self._pretrained_config is None:
+ return 0
+ return self._pretrained_config.feat_dim
+
+ @pretrained_config.setter
+ def pretrained_config(self, config: PretrainedFeatureExtractorConfig):
+ if isinstance(config, dict):
+ from trackastra.data.pretrained_features import (
+ PretrainedFeatureExtractorConfig,
+ )
+ config = PretrainedFeatureExtractorConfig.from_dict(config)
+ self.update_pretrained_feat_dim(config)
+ self._pretrained_config = config
+
+ def update_pretrained_feat_dim(self, config):
+ try:
+ self.FEATURES_DIMENSIONS["pretrained_feats"] = config.feat_dim
+ except AttributeError as e:
+ if isinstance(config, dict):
+ self.FEATURES_DIMENSIONS["pretrained_feats"] = config["feat_dim"]
+ else:
+ raise e
+
+ @staticmethod
+    def get_feat_dim(features, ndim):
+ return CTCData.FEATURES_DIMENSIONS[features][ndim]
+
@classmethod
def from_arrays(cls, imgs: np.ndarray, masks: np.ndarray, train_args: dict):
self = cls(**train_args)
+ # def from_ctc
# for key, value in train_args.items():
# setattr(self, key, value)
@@ -291,13 +500,11 @@ def from_arrays(cls, imgs: np.ndarray, masks: np.ndarray, train_args: dict):
# self.img_folder = self._guess_img_folder(self.root)
# logger.info(f"IMG:\t\t{self.img_folder}")
- self.feat_dim, self.augmenter, self.cropper = self._setup_features_augs(
- self.ndim, self.features, self.augment, self.crop_size
- )
+ self.augmenter, self.cropper = self._setup_features_augs()
start = default_timer()
- if self.features == "wrfeat":
+        if self.features in ("wrfeat", "pretrained_feats"):
self.windows = self._load_wrfeat()
else:
self.windows = self._load()
@@ -317,7 +524,19 @@ def from_arrays(cls, imgs: np.ndarray, masks: np.ndarray, train_args: dict):
if self.compress:
self._compress_data()
-
+
+ def _get_ndim_and_nobj(self, start):
+ if len(self.windows) > 0:
+ self.ndim = self.windows[0]["coords"].shape[1]
+ self.n_objects = tuple(len(t["coords"]) for t in self.windows)
+ logger.info(
+ f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track"
+ f" windows from {self.root} ({default_timer() - start:.1f}s)\n"
+ )
+ else:
+ self.n_objects = 0
+ logger.warning(f"Could not load any tracks from {self.root}")
+
def _get_ndivs(self, windows):
n_divs = []
for w in tqdm(windows, desc="Counting divisions", leave=False):
@@ -336,59 +555,47 @@ def _get_ndivs(self, windows):
return n_divs
def _setup_features_augs(
- self, ndim: int, features: str, augment: int, crop_size: tuple[int]
+ self
):
- if self.features == "wrfeat":
- return self._setup_features_augs_wrfeat(ndim, features, augment, crop_size)
+        if self.features in ("wrfeat", "pretrained_feats"):
+ return self._setup_features_augs_wrfeat()
cropper = (
RandomCrop(
- crop_size=crop_size,
- ndim=ndim,
+ crop_size=self.crop_size,
+ ndim=self.ndim,
use_padding=False,
ensure_inside_points=True,
)
- if crop_size is not None
+ if self.crop_size is not None
else None
)
# Hack
if self.features == "none":
- return 0, default_augmenter, cropper
-
- if ndim == 2:
- augmenter = AugmentationPipeline(p=0.8, level=augment) if augment else None
- feat_dim = {
- "none": 0,
- "regionprops": 7,
- "regionprops2": 6,
- "patch": 256,
- "patch_regionprops": 256 + 5,
- }[features]
- elif ndim == 3:
- augmenter = AugmentationPipeline(p=0.8, level=augment) if augment else None
- feat_dim = {
- "none": 0,
- "regionprops2": 11,
- "patch_regionprops": 256 + 8,
- }[features]
-
- return feat_dim, augmenter, cropper
+ return default_augmenter, cropper
+
+ augmenter = AugmentationPipeline(p=0.8, level=self.augment_level) if self.augment_level else None
+
+ return augmenter, cropper
def _compress_data(self):
# compress masks and assoc_matrix
logger.info("Compressing masks and assoc_matrix to save memory")
for w in self.windows:
- w["mask"] = _CompressedArray(w["mask"])
+ if "mask" in w:
+ w["mask"] = _CompressedArray(w["mask"])
# dont compress full imgs (as needed for patch features)
- w["img"] = _CompressedArray(w["img"])
+ if "img" in w:
+ w["img"] = _CompressedArray(w["img"])
w["assoc_matrix"] = _CompressedArray(w["assoc_matrix"])
self.gt_masks = _CompressedArray(self.gt_masks)
self.det_masks = {k: _CompressedArray(v) for k, v in self.det_masks.items()}
# dont compress full imgs (as needed for patch features)
self.imgs = _CompressedArray(self.imgs)
- def _guess_root_and_gt_tra_folder(self, inp: Path):
+ @staticmethod
+ def _guess_root_and_gt_tra_folder(inp: Path):
"""Guesses the root and the ground truth folder from a given input path.
Args:
@@ -410,15 +617,17 @@ def _guess_root_and_gt_tra_folder(self, inp: Path):
tra = ctc_tra if ctc_tra.exists() else inp / "TRA"
# 01 --> 01, 01_GT/TRA or 01/TRA
return inp, tra
-
- def _guess_img_folder(self, root: Path):
+
+ @staticmethod
+ def _guess_img_folder(root: Path):
"""Guesses the image folder corresponding to a root."""
if (root / "img").exists():
return root / "img"
else:
return root
- def _guess_mask_folder(self, root: Path, gt_tra: Path):
+ @staticmethod
+ def _guess_mask_folder(root: Path, gt_tra: Path):
"""Guesses the mask folder corresponding to a root.
In CTC format, we use silver truth segmentation masks.
@@ -452,7 +661,7 @@ def _guess_det_folder(cls, root: Path, suffix: str):
def __len__(self):
return len(self.windows)
-
+
def _load_gt(self):
logger.info("Loading ground truth")
self.start_frame = int(
@@ -539,7 +748,7 @@ def _load_tiffs(self, folder: Path, dtype=None):
logger.debug(f"Loaded array of shape {x.shape} from {folder}")
return x
- def _masks2properties(self, masks):
+ def _masks2properties(self, masks, return_props_by_time=False):
"""Turn label masks into lists of properties, sorted (ascending) by time and label id.
Args:
@@ -550,38 +759,7 @@ def _masks2properties(self, masks):
ts: List of timepoints
coords: List of coordinates
"""
- # Get coordinates, timepoints, and labels of detections
- labels = []
- ts = []
- coords = []
- properties_by_time = dict()
- assert len(self.imgs) == len(masks)
- for _t, frame in tqdm(
- enumerate(masks),
- # total=len(detections),
- leave=False,
- desc="Loading masks and properties",
- ):
- regions = regionprops(frame)
- t_labels = []
- t_ts = []
- t_coords = []
- for _r in regions:
- t_labels.append(_r.label)
- t_ts.append(_t)
- centroid = np.array(_r.centroid).astype(int)
- t_coords.append(centroid)
-
- properties_by_time[_t] = dict(coords=t_coords, labels=t_labels)
- labels.extend(t_labels)
- ts.extend(t_ts)
- coords.extend(t_coords)
-
- labels = np.array(labels, dtype=int)
- ts = np.array(ts, dtype=int)
- coords = np.array(coords, dtype=int)
-
- return labels, ts, coords, properties_by_time
+ return masks2properties(self.imgs, masks, return_props_by_time=return_props_by_time)
def _load_tracklet_links(self, folder: Path) -> pd.DataFrame:
df = pd.read_csv(
@@ -619,16 +797,17 @@ def _check_dimensions(self, x: np.ndarray):
raise ValueError(f"Expected 3D data, got {x.ndim - 1}D data")
return x
- def _load(self):
+ def _prepare_masks_and_imgs(self, return_orig_imgs=False):
# Load ground truth
- logger.info("Loading ground truth")
self.gt_masks, self.gt_track_df = self._load_gt()
-
self.gt_masks = self._check_dimensions(self.gt_masks)
# Load images
if self.img_folder is None:
- self.imgs = np.zeros_like(self.gt_masks)
+ if self.gt_masks is not None:
+ self.imgs = np.zeros_like(self.gt_masks)
+ else:
+ raise NotImplementedError("No images and no GT masks")
else:
logger.info("Loading images")
imgs = self._load_tiffs(self.img_folder, dtype=np.float32)
@@ -644,8 +823,38 @@ def _load(self):
for im, mask in zip(self.imgs, self.gt_masks)
]
)
+ if np.any(np.isnan(self.imgs)):
+ raise ValueError("Compressed images contain NaN values")
assert len(self.gt_masks) == len(self.imgs)
+ if return_orig_imgs:
+ return imgs
+
+ def _load(self):
+ self._prepare_masks_and_imgs()
# Load each of the detection folders and create data samples with a sliding window
windows = []
@@ -662,7 +871,7 @@ def _load(self):
det_ts,
det_coords,
det_properties_by_time,
- ) = self._masks2properties(det_masks)
+ ) = self._masks2properties(det_masks, return_props_by_time=True)
det_gt_matching = {
t: {_l: _l for _l in det_properties_by_time[t]["labels"]}
@@ -684,7 +893,7 @@ def _load(self):
det_ts,
det_coords,
det_properties_by_time,
- ) = self._masks2properties(det_masks)
+ ) = self._masks2properties(det_masks, return_props_by_time=True)
# FIXME matching can be slow for big images
# raise NotImplementedError("Matching not implemented for 3d version")
@@ -703,6 +912,8 @@ def _load(self):
self.properties_by_time[_f] = det_properties_by_time
self.det_masks[_f] = det_masks
+
+ # Build windows
_w = self._build_windows(
det_folder,
det_masks,
@@ -804,10 +1015,32 @@ def _build_windows(
logger.debug(f"Built {len(windows)} track windows from {det_folder}.\n")
return windows
-
+
+ def _apply_transform_and_check(self, img, labels, mask, coords, timepoints, min_time, assoc_matrix):
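+        """Applies the augmenter and keeps labels, timepoints and the
+        association matrix in sync with the surviving detections.
+
+        If augmentation would leave no trajectories, the inputs are returned
+        unchanged.
+        """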
+ (img2, mask2, coords2), idx = self.augmenter(
+ img, mask, coords, timepoints - min_time
+ )
+ if len(idx) > 0:
+ img, mask, coords = img2, mask2, coords2
+ labels = labels[idx]
+ timepoints = timepoints[idx]
+ assoc_matrix = assoc_matrix[idx][:, idx]
+ mask = mask.astype(int)
+ else:
+ logger.debug(
+                "Disabling augmentation as no trajectories would be left"
+ )
+ return img, labels, mask, coords, timepoints, assoc_matrix
+
+ @staticmethod
+ def decompress(data):
+ if isinstance(data, _CompressedArray):
+ return data.decompress()
+ return data
+
def __getitem__(self, n: int, return_dense=None):
# if not set, use default
- if self.features == "wrfeat":
+        if self.features in ("wrfeat", "pretrained_feats"):
return self._getitem_wrfeat(n, return_dense)
if return_dense is None:
@@ -822,12 +1055,15 @@ def __getitem__(self, n: int, return_dense=None):
timepoints = track["timepoints"]
min_time = track["t1"]
- if isinstance(mask, _CompressedArray):
- mask = mask.decompress()
- if isinstance(img, _CompressedArray):
- img = img.decompress()
- if isinstance(assoc_matrix, _CompressedArray):
- assoc_matrix = assoc_matrix.decompress()
+ mask = CTCData.decompress(mask)
+ img = CTCData.decompress(img)
+ assoc_matrix = CTCData.decompress(assoc_matrix)
# cropping
if self.cropper is not None:
@@ -853,19 +1089,9 @@ def __getitem__(self, n: int, return_dense=None):
elif self.features in ("regionprops", "regionprops2"):
if self.augmenter is not None:
- (img2, mask2, coords2), idx = self.augmenter(
- img, mask, coords, timepoints - min_time
+ img, labels, mask, coords, timepoints, assoc_matrix = self._apply_transform_and_check(
+ img, labels, mask, coords, timepoints, min_time, assoc_matrix
)
- if len(idx) > 0:
- img, mask, coords = img2, mask2, coords2
- labels = labels[idx]
- timepoints = timepoints[idx]
- assoc_matrix = assoc_matrix[idx][:, idx]
- mask = mask.astype(int)
- else:
- logger.debug(
- "disable augmentation as no trajectories would be left"
- )
features = tuple(
extract_features_regionprops(
@@ -875,21 +1101,11 @@ def __getitem__(self, n: int, return_dense=None):
)
features = np.concatenate(features, axis=0)
# features = np.zeros((len(coords), self.feat_dim))
-
elif self.features == "patch":
if self.augmenter is not None:
- (img2, mask2, coords2), idx = self.augmenter(
- img, mask, coords, timepoints - min_time
+ img, labels, mask, coords, timepoints, assoc_matrix = self._apply_transform_and_check(
+ img, labels, mask, coords, timepoints, min_time, assoc_matrix
)
- if len(idx) > 0:
- img, mask, coords = img2, mask2, coords2
- labels = labels[idx]
- timepoints = timepoints[idx]
- assoc_matrix = assoc_matrix[idx][:, idx]
- mask = mask.astype(int)
- else:
- print("disable augmentation as no trajectories would be left")
-
features = tuple(
extract_features_patch(
m,
@@ -902,18 +1118,9 @@ def __getitem__(self, n: int, return_dense=None):
features = np.concatenate(features, axis=0)
elif self.features == "patch_regionprops":
if self.augmenter is not None:
- (img2, mask2, coords2), idx = self.augmenter(
- img, mask, coords, timepoints - min_time
+ img, labels, mask, coords, timepoints, assoc_matrix = self._apply_transform_and_check(
+ img, labels, mask, coords, timepoints, min_time, assoc_matrix
)
- if len(idx) > 0:
- img, mask, coords = img2, mask2, coords2
- labels = labels[idx]
- timepoints = timepoints[idx]
- assoc_matrix = assoc_matrix[idx][:, idx]
- mask = mask.astype(int)
- else:
- print("disable augmentation as no trajectories would be left")
-
features1 = tuple(
extract_features_patch(
m,
@@ -939,6 +1146,25 @@ def __getitem__(self, n: int, return_dense=None):
)
features = np.concatenate(features, axis=0)
+        # "pretrained_feats" extraction moved to wrfeat.WRPretrainedFeatures (see _getitem_wrfeat)
# remove temporal offset and add timepoints to coords
relative_timepoints = timepoints - track["t1"]
@@ -957,14 +1183,14 @@ def __getitem__(self, n: int, return_dense=None):
)
coords0 = torch.from_numpy(coords).float()
- features = torch.from_numpy(features).float()
+ features = torch.from_numpy(features).float() if isinstance(features, np.ndarray) else features.float()
assoc_matrix = torch.from_numpy(assoc_matrix.copy()).float()
labels = torch.from_numpy(labels).long()
timepoints = torch.from_numpy(timepoints).long()
if self.augmenter is not None:
coords = coords0.clone()
- coords[:, 1:] += torch.randint(0, 256, (1, self.ndim))
+ coords[:, 1:] += torch.randint(-1024, 1024, (1, self.ndim))
else:
coords = coords0.clone()
res = dict(
@@ -983,92 +1209,51 @@ def __getitem__(self, n: int, return_dense=None):
mask = torch.from_numpy(mask.astype(int)).long()
res["mask"] = mask
-
return res
# wrfeat functions...
# TODO: refactor this as a subclass or make everything a class factory. *very* hacky this way
+ # -> updated _setup_features_augs_wrfeat to use a factory instead
def _setup_features_augs_wrfeat(
- self, ndim: int, features: str, augment: int, crop_size: tuple[int]
+ self
):
- # FIXME: hardcoded
- feat_dim = 7 if ndim == 2 else 12
- if augment == 1:
- augmenter = wrfeat.WRAugmentationPipeline(
- [
- wrfeat.WRRandomFlip(p=0.5),
- wrfeat.WRRandomAffine(
- p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
- ),
- # wrfeat.WRRandomBrightness(p=0.8, factor=(0.5, 2.0)),
- # wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
- ]
- )
- elif augment == 2:
- augmenter = wrfeat.WRAugmentationPipeline(
- [
- wrfeat.WRRandomFlip(p=0.5),
- wrfeat.WRRandomAffine(
- p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
- ),
- wrfeat.WRRandomBrightness(p=0.8),
- wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
- ]
- )
- elif augment == 3:
- augmenter = wrfeat.WRAugmentationPipeline(
- [
- wrfeat.WRRandomFlip(p=0.5),
- wrfeat.WRRandomAffine(
- p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
- ),
- wrfeat.WRRandomBrightness(p=0.8),
- wrfeat.WRRandomMovement(offset=(-10, 10), p=0.3),
- wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
- ]
- )
- else:
- augmenter = None
+ augmenter = wrfeat.AugmentationFactory.create_augmentation_pipeline(self.augment_level)
+ cropper = wrfeat.AugmentationFactory.create_cropper(self.crop_size, self.ndim) if self.crop_size is not None else None
- cropper = (
- wrfeat.WRRandomCrop(
- crop_size=crop_size,
- ndim=ndim,
- )
- if crop_size is not None
- else None
- )
- return feat_dim, augmenter, cropper
+ return augmenter, cropper
def _load_wrfeat(self):
- # Load ground truth
- self.gt_masks, self.gt_track_df = self._load_gt()
- self.gt_masks = self._check_dimensions(self.gt_masks)
-
- # Load images
- if self.img_folder is None:
- if self.gt_masks is not None:
- self.imgs = np.zeros_like(self.gt_masks)
- else:
- raise NotImplementedError("No images and no GT masks")
- else:
- logger.info("Loading images")
- imgs = self._load_tiffs(self.img_folder, dtype=np.float32)
- self.imgs = np.stack(
- [normalize(_x) for _x in tqdm(imgs, desc="Normalizing", leave=False)]
- )
- self.imgs = self._check_dimensions(self.imgs)
- if self.compress:
- # prepare images to be compressed later (e.g. removing non masked parts for regionprops features)
- self.imgs = np.stack(
- [
- _compress_img_mask_preproc(im, mask, self.features)
- for im, mask in zip(self.imgs, self.gt_masks)
- ]
- )
-
- assert len(self.gt_masks) == len(self.imgs)
+ imgs = self._prepare_masks_and_imgs(return_orig_imgs=True)
# Load each of the detection folders and create data samples with a sliding window
windows = []
@@ -1114,13 +1299,42 @@ def _load_wrfeat(self):
self.det_masks[_f] = det_masks
# build features
-
- features = joblib.Parallel(n_jobs=8)(
- joblib.delayed(wrfeat.WRFeatures.from_mask_img)(
- mask=mask[None], img=img[None], t_start=t
+ if self.features == "pretrained_feats":
+ self._setup_pretrained_feature_extractor()
+ if np.all(self.imgs == 0):
+ raise ValueError("Images are empty. Images must be provided when using pretrained features")
+ self.feature_extractor.precompute_image_embeddings(imgs) # use NON_NORMALIZED images for pretrained features
+ # normalization is performed in the feature extractor
+ features = [
+ wrfeat.WRPretrainedFeatures.from_mask_img(
+ img=img[np.newaxis],
+ mask=mask[np.newaxis],
+ feature_extractor=self.feature_extractor,
+ t_start=t,
+ additional_properties=self.pretrained_config.additional_features
+ )
+ for t, (mask, img) in enumerate(zip(det_masks, self.imgs))
+ ]
+ for wrf in features:
+ feats = wrf.features_stacked
+                if feats is not None and np.any(np.isnan(feats)):
+ raise ValueError("NaN in features")
+ if torch.cuda.is_available():
+ self.feature_extractor.embeddings = self.feature_extractor.embeddings.cpu()
+ torch.cuda.empty_cache()
+ elif self.features == "wrfeat":
+ features = joblib.Parallel(n_jobs=8)(
+ joblib.delayed(wrfeat.WRFeatures.from_mask_img)(
+ mask=mask[None], img=img[None], t_start=t
+ )
+ for t, (mask, img) in enumerate(zip(det_masks, self.imgs))
)
- for t, (mask, img) in enumerate(zip(det_masks, self.imgs))
- )
properties_by_time = dict()
for _t, _feats in enumerate(features):
@@ -1228,26 +1442,39 @@ def _getitem_wrfeat(self, n: int, return_dense=None):
labels = labels[idx]
timepoints = timepoints[idx]
assoc_matrix = assoc_matrix[idx][:, idx]
- else:
- logger.debug("Skipping cropping")
if self.augmenter is not None:
feat = self.augmenter(feat)
-
+
coords0 = np.concatenate((feat.timepoints[:, None], feat.coords), axis=-1)
coords0 = torch.from_numpy(coords0).float()
assoc_matrix = torch.from_numpy(assoc_matrix.astype(np.float32))
- features = torch.from_numpy(feat.features_stacked).float()
+ # if self.pca_preprocessor is not None:
+ # features = self.pca_preprocessor.transform(feat.features_stacked)
+ # else:
+ features = feat.features_stacked
+ if features is not None:
+ features = torch.from_numpy(features).float()
+
labels = torch.from_numpy(feat.labels).long()
timepoints = torch.from_numpy(feat.timepoints).long()
-
+
+ pretrained_features = feat.pretrained_feats
+ if pretrained_features is not None:
+ pretrained_features = torch.from_numpy(pretrained_features).float()
+
if self.max_tokens and len(timepoints) > self.max_tokens:
time_incs = np.where(timepoints - np.roll(timepoints, 1))[0]
n_elems = time_incs[np.searchsorted(time_incs, self.max_tokens) - 1]
timepoints = timepoints[:n_elems]
labels = labels[:n_elems]
coords0 = coords0[:n_elems]
- features = features[:n_elems]
+ if features is not None:
+ features = features[:n_elems]
+ if pretrained_features is not None:
+ pretrained_features = pretrained_features[:n_elems]
assoc_matrix = assoc_matrix[:n_elems, :n_elems]
logger.debug(
f"Clipped window of size {timepoints[n_elems - 1] - timepoints.min()}"
@@ -1258,15 +1485,28 @@ def _getitem_wrfeat(self, n: int, return_dense=None):
coords[:, 1:] += torch.randint(0, 512, (1, self.ndim))
else:
coords = coords0.clone()
+
+ if self.features == "pretrained_feats" and self.rotate_feats:
+ if isinstance(img, _CompressedArray):
+ image_shape = img._shape
+ else:
+ image_shape = img.shape
+ # logger.debug(f"Rotating pretrained features with shape {pretrained_features.shape} for image shape {image_shape}")
+ pretrained_features = CTCData.rotate_features(
+ pretrained_features, coords, image_shape,
+ n_rot_dims=self.pretrained_feat_dim,
+ )
+
res = dict(
features=features,
+ pretrained_features=pretrained_features,
coords0=coords0,
coords=coords,
assoc_matrix=assoc_matrix,
timepoints=timepoints,
labels=labels,
)
-
+
if return_dense:
if all([x is not None for x in img]):
img = torch.from_numpy(img).float()
@@ -1274,7 +1514,809 @@ def _getitem_wrfeat(self, n: int, return_dense=None):
mask = torch.from_numpy(mask.astype(int)).long()
res["mask"] = mask
+
+ if features is not None:
+ if torch.any(torch.isnan(features)):
+ raise ValueError("NaN in features")
+ elif torch.any(torch.all(features == 0, dim=-1)):
+ raise ValueError("Empty features")
+
+ if pretrained_features is not None:
+ if torch.any(torch.isnan(pretrained_features)):
+ raise ValueError("NaN in pretrained features")
+ elif torch.any(torch.all(pretrained_features == 0, dim=-1)):
+ raise ValueError("Empty pretrained features")
+
+ return res
+
+ def _get_pretrained_features_save_path(self):
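+        # e.g. root .../vanvliet/cib/140409-03 -> <save_path>/embeddings/vanvliet_cib_140409-03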
+ if self.pretrained_config is not None:
+ img_folder_name = "_".join(self.root.parts[-3:]) if len(self.root.parts) >= 3 else "_".join(self.root.parts)
+ img_folder_name = str(img_folder_name).replace(".", "").replace("/", "_").replace("\\", "_").replace(" ", "_")
+ return self.pretrained_config.save_path / f"embeddings/{img_folder_name}"
+
+ def _setup_pretrained_feature_extractor(self):
+ if self.ndim == 3:
+ raise ValueError("Pretrained model feature extraction is not implemented for 3D data")
+ img_shape = self.imgs.shape[-2:] # initial guess, replaced later if shape changes
+ from trackastra.data.pretrained_features import (
+ FeatureExtractor,
+ )
+ self.feature_extractor_save_path = self._get_pretrained_features_save_path()
+ self.feature_extractor = FeatureExtractor.from_config(
+ self.pretrained_config,
+ image_shape=img_shape,
+ save_path=self.feature_extractor_save_path,
+ )
+ self.feature_extractor_input_size = self.feature_extractor.input_size
+
+ def _compute_pretrained_model_features(self):
+ if self.pretrained_config.model_name is None:
+ logger.warning("No pretrained model set, feature extraction not run")
+ return
+
+ try:
+ self.feature_extractor.input_mul = self._pretrained_model_input_size_factor
+ except Exception:
+ logger.warning(f"Cannot change input size for pretrained model: {self.pretrained_config.model_name}")
+ self.pretrained_features = self.feature_extractor.precompute_region_embeddings(self.imgs)
+ # dict(n_frames) : torch.Tensor(n_regions_in_frame, n_features)
+ self.feature_extractor = None
+
+    def compute_pretrained_features(
+        self,
+        input_size_factor: int | None = None,
+        model: PretrainedBackboneType | None = None,
+        mode: PretrainedFeatsExtractionMode | None = "nearest_patch",
+    ):
+        """Recomputes pretrained features if the model, input size factor or mode changed.
+
+ Args:
+ input_size_factor (int, optional): The input size factor for the pretrained model. Defaults to None.
+ model (PretrainedBackboneType, optional): The pretrained model to use. Defaults to None.
+ mode (PretrainedFeatsExtractionMode, optional): The mode to use for feature extraction. Defaults to "nearest_patch".
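+
+        Example (illustrative; `ds` is a CTCData built with pretrained features):
+            >>> ds.compute_pretrained_features(input_size_factor=2, mode="mean_patches_bbox")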
+ """
+ if input_size_factor is not None:
+ self._pretrained_model_input_size_factor = input_size_factor
+ logger.debug(f"Setting input size factor to {input_size_factor}")
+ if model is not None:
+ self.pretrained_config.model_name = model
+ logger.debug(f"Setting pretrained model to {model}")
+ if mode is not None:
+ self.pretrained_config.mode = mode
+ logger.debug(f"Setting feature extraction mode to {mode}")
+        if input_size_factor is None and model is None and mode is None:
+            logger.warning("No changes in input size factor, model or mode. Skipping feature extraction.")
+            return
+        self._compute_pretrained_model_features()
+
+ @staticmethod
+ def rotate_features(
+ features: torch.Tensor,
+ coords: torch.Tensor,
+ image_shape: tuple,
+ n_rot_dims: int | None = None,
+ skip_first: int = 0,
+ ) -> torch.Tensor:
+ """Applies a RoPE-style rotation to each feature vector based on the object's centroid.
+
+ Args:
+ features: (n_objects, hidden_state_size) tensor of features.
+            coords: (n_objects, 1 + ndim) tensor of (t, y, x) coordinates.
+ image_shape: (time, height, width) shape of the input image.
+ n_rot_dims: Number of feature dimensions to apply rotation to (must be even). If None, rotate all.
+ skip_first: Number of dimensions to skip at the beginning of the feature vector. No effect if 0.
+
+ Returns:
+ Rotated features: (n_objects, hidden_state_size)
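+
+        Example (illustrative shapes only):
+            >>> feats = torch.randn(5, 64)
+            >>> coords = torch.zeros(5, 3)  # (t, y, x) per object
+            >>> CTCData.rotate_features(feats, coords, (10, 256, 256), n_rot_dims=64).shape
+            torch.Size([5, 64])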
+ """
+ import math
+        N, D = features.shape
+        if n_rot_dims is None:
+            n_rot_dims = D
+        assert skip_first < n_rot_dims, "skip_first must be less than n_rot_dims."
+ if skip_first != 0:
+ n_rot_dims = n_rot_dims - skip_first
+ assert n_rot_dims % 2 == 0, "n_rot_dims must be even for RoPE."
+ assert n_rot_dims <= D, "n_rot_dims cannot exceed feature dimension."
+
+        # Normalize the two spatial coordinates to [0, 1] (coords[:, 0] is the timepoint)
+ H, W = image_shape[-2], image_shape[-1]
+ x_norm = coords[:, 1] / H
+ y_norm = coords[:, 2] / W
+ # Compute two angles for x and y
+ angle_x = 2 * math.pi * x_norm
+ angle_y = 2 * math.pi * y_norm
+ # Interleave angles for each feature pair
+ angles = torch.stack([angle_x, angle_y], dim=1).repeat(1, n_rot_dims // 2)
+ angles = angles.view(N, n_rot_dims)
+ # Prepare cos/sin
+ cos = torch.cos(angles)
+ sin = torch.sin(angles)
+ # Interleave features for rotation
+        features_rot = features[:, skip_first:n_rot_dims + skip_first].view(N, -1, 2)
+ x_feat, y_feat = features_rot[..., 0], features_rot[..., 1]
+ x_rot = x_feat * cos[:, ::2] - y_feat * sin[:, ::2]
+ y_rot = x_feat * sin[:, ::2] + y_feat * cos[:, ::2]
+ rotated_part = torch.stack([x_rot, y_rot], dim=-1).reshape(N, n_rot_dims)
+        if skip_first > 0 or n_rot_dims + skip_first < D:
+            rotated = torch.cat(
+                [features[:, :skip_first], rotated_part, features[:, n_rot_dims + skip_first:]],
+                dim=1,
+            )
+        else:
+            rotated = rotated_part
+ return rotated # (n_objects, d)
+
+
+class CTCDataAugPretrainedFeats(CTCData):
+ """CTCData with pretrained features."""
+ def __init__(
+ self,
+ pretrained_n_augmentations: int = 3,
+ n_aug_workers: int = 8,
+        force_recompute: bool = False,
+        aug_pipeline: PretrainedAugmentations | None = None,
+ load_from_disk: bool = False,
+ *args,
+ **kwargs
+ ):
+        """Initializes the dataset with offline-augmented pretrained features.
+
+        Args:
+ root (str):
+ Folder containing the CTC TRA folder.
+ ndim (int):
+ Number of dimensions of the data. Defaults to 2d
+ (if ndim=3 and data is two dimensional, it will be cast to 3D)
+ detection_folders:
+ List of relative paths to folder with detections.
+ Defaults to ["TRA"], which uses the ground truth detections.
+ window_size (int):
+ Window size for transformer.
+ slice_pct (tuple):
+ Slice the dataset by percentages (from, to).
+ augment (int):
+ if 0, no data augmentation. if > 0, defines level of data augmentation.
+ features (str):
+ Types of features to use.
+ sanity_dist (bool):
+                Use euclidean distance instead of the association matrix as a target.
+ crop_size (tuple):
+ Size of the crops to use for augmentation. If None, no cropping is used.
+ return_dense (bool):
+ Return dense masks and images in the data samples.
+ compress (bool):
+ Compress elements/remove img if not needed to save memory for large datasets
+ pretrained_backbone_config (PretrainedFeatureExtractorConfig):
+ Configuration for the pretrained backbone.
+ If mode is set to "pretrained_feats", this configuration is used to extract features.
+ Ignored otherwise.
+ pretrained_n_augmentations (int):
+ How many augmented versions of the pretrained model embeddings to create.
+ n_aug_workers (int):
+ Number of workers to use for offline augmentation.
+ load_from_disk (bool):
+ If True, the offline augmented windows are saved to disk and sampled from there.
+ If False, all windows are loaded into RAM and sampled from there.
+ force_recompute (bool):
+ If False, previously computed offline augmentations are loaded if available.
+ # pca_preprocessor (EmbeddingsPCACompression):
+ # PCA preprocessor for the pretrained features.
+ # If mode is set to "pretrained_feats", this is used to reduce the dimensionality of the features.
+ # Ignored otherwise.
+ """
+ features = kwargs.get("features", None)
+
+        if features is not None and features != "pretrained_feats_aug":
+ raise ValueError("This class should only be used with pretrained_feats_aug features")
+
+ self.n_augs = pretrained_n_augmentations
+ self.n_aug_workers = n_aug_workers
+ self.force_recompute = force_recompute
+ self.load_from_disk = load_from_disk
+
+ from trackastra.data.pretrained_augmentations import (
+ PretrainedMovementAugmentations,
+ )
+ self.pretrained_feats_augmenter = PretrainedMovementAugmentations(rng_seed=42) if aug_pipeline is None else aug_pipeline
+ if not isinstance(self.pretrained_feats_augmenter, PretrainedAugmentations):
+ raise ValueError(
+ f"Augmentation pipeline must be of type PretrainedAugmentations, got {type(self.pretrained_feats_augmenter)}"
+ )
+ logger.debug(self.pretrained_feats_augmenter)
+ self.augmented_feature_extractor = None
+ self.augmented_image_shapes = None # used to store the augmented image shapes, used to rotate features
+ self.save_windows = True
+
+ self._aug_embeds_file = None # stores the augmented per-object embeddings
+ self.delete_augs_after_loading = False
+ # self.window_save_path = None
+ self._last_selected = None
+ self._rng = np.random.default_rng()
+ self._len = None
+ self._debug = False
+
+ super().__init__(*args, **kwargs, load_immediately=False)
+
+ if self.load_from_disk:
+ self.window_save_path = self._get_pretrained_features_save_path() / "windows"
+ self.window_save_path.mkdir(parents=True, exist_ok=True)
+ self.window_save_path = self.window_save_path / f"{self.config_hash}.zarr"
+ logger.debug(f"Windows will be saved to {self.window_save_path}")
+ else:
+ self.window_save_path = None
+ logger.debug("Windows will be loaded into RAM")
+
+ if kwargs.get("load_immediately", True): # hook to delay loading if needed
+ self.start_loading()
+ start = default_timer()
+ else:
+ start = None
+
+ logger.debug("Loading finished, clearing feature extractors...")
+
+ # Clear pre-trained model
+ self.augmented_feature_extractor = None
+ # Clear windows, as they are loaded from disk when __getitem__ is called
+ if self.load_from_disk:
+ self._get_ndim_and_nobj(start, self.windows)
+ self.windows = None
+ # Clear intermediate data
+ if self.delete_augs_after_loading and self._aug_embeds_file.exists():
+ try:
+ self._aug_embeds_file.close()
+ except Exception as e:
+ logger.warning(f"Could not close HDF5 file: {e}")
+ try:
+ self._aug_embeds_file.unlink()
+ self._aug_embeds_file = None
+ except Exception as e:
+ logger.warning(f"Could not delete file {self._aug_embeds_file}: {e}")
+ logger.info("Feature extractors cleared.")
+ else:
+ self._get_ndim_and_nobj(start, self.windows)
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ @property
+ def config(self):
+ cfg = super().config
+ cfg["pretrained_n_augmentations"] = self.n_augs
+ cfg["pretrained_augmentations"] = self.pretrained_feats_augmenter.get_signature()
+ return cfg
+
+ @property
+ def feat_dim(self):
+ return self.pretrained_config.feat_dim
+
+ def _init_features(self):
+ self.windows = self._load()
+ if self.load_from_disk:
+ self._save_windows()
+ else:
+ self._get_ndim_and_nobj(None, self.windows)
+
+ def _setup_features_augs(
+ self
+ ):
+ logger.debug(f"Creating augmentations with level {self.augment_level}")
+ augmenter = wrfeat.AugmentationFactory.create_augmentation_pipeline(self.augment_level, return_type=wrfeat.WRAugPretrainedFeatures)
+ cropper = wrfeat.AugmentationFactory.create_cropper(self.crop_size, self.ndim, return_type=wrfeat.WRAugPretrainedFeatures) if self.crop_size is not None else None
+
+ return augmenter, cropper
+
+ def _get_ndim_and_nobj(self, start=None, windows=None):
+ if windows is not None:
+ self.ndim = windows[0]["coords"][0][0].shape[0]
+ self.n_objects = tuple(len(t["coords"][0]) for t in windows)
+ if self.save_windows and self.load_from_disk:
+ return
+ if len(self.windows) > 0:
+ self.ndim = self.windows[0]["coords"][0].shape[1]
+ self.n_objects = tuple(len(t["coords"][0]) for t in self.windows)
+ if start is None:
+ logger.info(
+ f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track"
+ f" windows from {self.root}\n"
+ )
+ else:
+ logger.info(
+ f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track"
+ f" windows from {self.root} ({default_timer() - start:.1f}s)\n"
+ )
+ else:
+ self.n_objects = 0
+ logger.warning(f"Could not load any tracks from {self.root}")
+
+ @classmethod
+ def from_arrays(cls, imgs: np.ndarray, masks: np.ndarray, train_args: dict):
+ raise NotImplementedError()
+
+ def _load(self):
+ all_windows = []
+ imgs = self._prepare_masks_and_imgs(return_orig_imgs=True)
+
+ # self.properties_by_time = dict()
+ self.det_masks = dict()
+ logger.info("Loading detections")
+ if len(self.detection_folders) > 1:
+ raise NotImplementedError("Pretrained aug features with several folders is not supported yet")
+
+ if self._load_windows() is not None:
+ return self.windows
+
+ for _f in self.detection_folders:
+ det_folder = self.root / _f
+
+ if det_folder == self.gt_mask_folder:
+ det_masks = self.gt_masks
+ logger.info("DET MASK:\tUsing GT masks")
+ # identity matching
+ (
+ det_labels,
+ det_ts,
+ _,
+ ) = self._masks2properties(det_masks)
+
+ det_gt_matching = {
+ t: {_l: _l for _l in set(np.unique(d)) - {0}}
+ for t, d in enumerate(det_masks)
+ }
+ else:
+ det_folder = self._guess_det_folder(root=self.root, suffix=_f)
+ if det_folder is None:
+ continue
+ logger.info(f"DET MASK (guessed):\t{det_folder}")
+ det_masks = self._load_tiffs(det_folder, dtype=np.int32)
+ det_masks = self._correct_gt_with_st(
+ det_folder, det_masks, dtype=np.int32
+ )
+ det_masks = self._check_dimensions(det_masks)
+ (
+ det_labels,
+ det_ts,
+ _,
+ ) = self._masks2properties(det_masks)
+ # FIXME matching can be slow for big images
+ # raise NotImplementedError("Matching not implemented for 3d version")
+ det_gt_matching = {
+ t: {
+ _d: _gt
+ for _gt, _d in matching(
+ self.gt_masks[t],
+ det_masks[t],
+ threshold=0.3,
+ max_distance=16,
+ )
+ }
+ for t in tqdm(range(len(det_masks)), leave=False, desc="Matching")
+ }
+
+ self.det_masks[_f] = det_masks
+ # Setup feature extractor
+ self._setup_pretrained_feature_extractor()
+
+ # Build augmentation pipeline
+ from trackastra.data.pretrained_features import FeatureExtractorAugWrapper
+ self.augmented_feature_extractor = FeatureExtractorAugWrapper(
+ extractor=self.feature_extractor,
+ augmenter=self.pretrained_feats_augmenter,
+ n_aug=self.n_augs,
+ force_recompute=self.force_recompute,
+ )
+ self._aug_embeds_file = self.augmented_feature_extractor.get_save_path()
+
+ # Compute features for all augmentations
+ augmented_dict = self.augmented_feature_extractor.compute_all_features(
+ images=imgs,
+ masks=det_masks,
+ clear_mem=not self.load_from_disk,
+ n_workers=self.n_aug_workers,
+ )
+ self.augmented_image_shapes = self.augmented_feature_extractor.image_shape_reference
+ # logger.debug(f"AUG DICT keys : {augmented_dict.keys()}")
+
+ _w = self._build_windows(
+ det_ts,
+ det_labels,
+ det_gt_matching,
+ augmented_dict
+ )
+ all_windows.extend(_w)
+
+ return all_windows
+
+ def _build_windows(self, ts, labels, matching, augmented_dict):
+ windows = []
+ window_size = self.window_size
+ n_frames = len(np.unique(ts))
+ n_entries = self.n_augs + 1
+ # augmented_dict structure :
+ # - aug_id:
+ # - metadata: dict, record of the applied augmentations and other metadata
+ # - data: the data for aug_id
+ # - t: frame between 0 and n_frames
+ # - lab: label of the detection
+ # - coords: coordinates of the detections for (t, lab)
+ # - features: dict of features for (t, lab)
+ # - feat_name_1: feature 1 for (t, lab)
+ # - ...
+ # - feat_name_n: feature n for (t, lab)
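+ # Illustrative example (hypothetical values) of that structure:
+ # augmented_dict = {
+ # "0": { # original, un-augmented entry
+ # "metadata": {"hflip": False, "image_shape": (T, 1, H, W)},
+ # "data": {0: {1: {"coords": np.array([12.0, 34.0]),
+ # "features": {"embedding": np.zeros(256)}}}},
+ # },
+ # "1": {...}, # first augmented copy, same layout
+ # }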
+
+ for t1, t2 in tqdm(
+ zip(range(0, n_frames), range(window_size, n_frames + 1)),
+ total=n_frames - window_size + 1,
+ leave=False,
+ desc="Building windows",
+ ):
+ idx = (ts >= t1) & (ts < t2)
+ _ts = ts[idx]
+ _labels = labels[idx]
+
+ _coords = {aug_id: [] for aug_id in range(n_entries)}
+ _features = {aug_id: {} for aug_id in range(n_entries)}
+ present_labels_per_aug = [set() for _ in range(n_entries)]
+
+ for aug_id in range(n_entries):
+ for t in range(t1, t2):
+ labels_at_t = _labels[_ts == t]
+ data = augmented_dict[str(aug_id)]["data"]
+
+ coords_at_t = []
+ for lab in labels_at_t:
+ try:
+ coords_at_t.append(data[t][lab]["coords"])
+ present_labels_per_aug[aug_id].add((t, lab))
+ except KeyError:
+ continue
+ if len(coords_at_t) == 0:
+ coords_at_t = np.zeros((0, self.ndim), dtype=int)
+ else:
+ coords_at_t = np.stack(coords_at_t, axis=0)
+ _coords[aug_id].extend(coords_at_t)
+
+ features_at_t = []
+ for lab in labels_at_t:
+ try:
+ features_at_t.append(data[t][lab]["features"])
+ except KeyError:
+ continue
+ if len(features_at_t) == 0:
+ features_at_t = {}
+ else:
+ features_at_t = [dict(f) for f in features_at_t]
+ for _f in features_at_t:
+ for k, v in _f.items():
+ if k not in _features[aug_id]:
+ _features[aug_id][k] = []
+ _features[aug_id][k].append(v)
+
+ # --- Filter labels missing in any augmentation --- #
+ # (This can happen due to downsampling in pretrained features,
+ # label has too few pixels to have any valid associated features.
+ # If this occurs for too many labels, check the data and augmentation settings.)
+ common_labels = set.intersection(*present_labels_per_aug)
+ keep_mask = np.array([(t, lab) in common_labels for t, lab in zip(_ts, _labels)])
+ if np.sum(~keep_mask) > 0:
+ missing_labels = {
+ (t, lab) for t, lab in zip(_ts[~keep_mask], _labels[~keep_mask])
+ }
+ logger.warning(
+ f"Labels were removed from window {t1} to {t2}, as those labels are"
+ " missing in some augmentations. If this occurs for too many labels,"
+ " check the data and ensure augmentation settings are appropriate."
+ )
+ logger.warning(f"Removed labels: {missing_labels}")
+ _labels = _labels[keep_mask]
+ _ts = _ts[keep_mask]
+ for aug_id in range(n_entries):
+ filtered_coords = []
+ filtered_features = {k: [] for k in _features[aug_id].keys()}
+ idx_counter = 0
+ for t in range(t1, t2):
+ labels_at_t = _labels[_ts == t]
+ n = len(labels_at_t)
+ filtered_coords.append(_coords[aug_id][idx_counter:idx_counter + n])
+ for k in _features[aug_id].keys():
+ filtered_features[k].extend(_features[aug_id][k][idx_counter:idx_counter + n])
+ idx_counter += n
+ _coords[aug_id] = np.concatenate(filtered_coords, axis=0) if filtered_coords else np.zeros((0, self.ndim), dtype=np.float32)
+ for k in filtered_features:
+ _features[aug_id][k] = np.array(filtered_features[k], dtype=np.float32)
+ else:
+ # No missing labels, just convert to arrays as usual
+ for aug_id in range(n_entries):
+ _coords[aug_id] = np.array(_coords[aug_id], dtype=np.float32)
+ for k, v in _features[aug_id].items():
+ _features[aug_id][k] = np.array(v, dtype=np.float32)
+
+ if len(_labels) == 0:
+ # raise ValueError(f"No detections in sample {det_folder}:{t1}") # empty frames can happen
+ A = np.zeros((0, 0), dtype=bool)
+ else:
+ A = _ctc_assoc_matrix(
+ _labels,
+ _ts,
+ self.gt_graph,
+ matching,
+ )
+
+ w = dict(
+ coords=_coords,
+ t1=t1,
+ # img=self.imgs[t1:t2],
+ # mask=det_masks[t1:t2],
+ assoc_matrix=A,
+ labels=_labels,
+ timepoints=_ts,
+ features=_features,
+ )
+ if len(_coords) != n_entries or len(_features) != n_entries:
+ raise ValueError(f"Number of coords {len(_coords)} or features {len(_features)} does not match number of augmentations {n_entries}")
+ windows.append(w)
+
+ logger.debug(f"Built {len(windows)} track windows.\n")
+ return windows
+
+ def _save_windows(self):
+ if self.window_save_path is not None:
+ self._len = len(self.windows)
+ logger.info(f"Saving windows to {self.window_save_path}")
+ mode = "w" if self.force_recompute else "a"
+ root = zarr.open_group(str(self.window_save_path), mode=mode)
+ for i, w in enumerate(self.windows):
+ group_name = f"window_{i}"
+ if group_name in root:
+ del root[group_name]
+ grp = root.create_group(group_name)
+ for aug_id in range(self.n_augs + 1):
+ grp.create_dataset(f"coords_{aug_id}", data=w["coords"][aug_id])
+ features_group = grp.create_group(f"features_{aug_id}")
+ for k, v in w["features"][aug_id].items():
+ features_group.create_dataset(k, data=v)
+ grp.create_dataset("labels", data=w["labels"])
+ grp.create_dataset("timepoints", data=w["timepoints"])
+ grp.create_dataset("assoc_matrix", data=w["assoc_matrix"])
+ grp.attrs["t1"] = w["t1"]
+ else:
+ raise ValueError("No window save path set. Cannot save windows.")
+
+ def _load_windows(self):
+ if not self.load_from_disk:
+ if getattr(self, "windows", None) is not None:
+ logger.debug("Windows already loaded into memory.")
+ return self.windows
+ return None
+ if self.window_save_path.exists() and not self.force_recompute:
+ self.windows = []
+ logger.info(f"Loading windows from {self.window_save_path}")
+ root = zarr.open_group(str(self.window_save_path), mode="r")
+ group_names = sorted(
+ (k for k in root.keys() if k.startswith("window_")),
+ key=lambda x: int(x.split("_")[1]),
+ )
+ logger.debug(f"Found {len(group_names)} windows, loading...")
+ for w in group_names:
+ grp = root[w]
+ coords = [grp[f"coords_{aug_id}"][...] for aug_id in range(self.n_augs + 1)]
+ features = {}
+ for aug_id in range(self.n_augs + 1):
+ features[aug_id] = {k: grp[f"features_{aug_id}"][k][...] for k in grp[f"features_{aug_id}"].keys()}
+ labels = grp["labels"][...]
+ timepoints = grp["timepoints"][...]
+ assoc_matrix = grp["assoc_matrix"][...]
+ t1 = grp.attrs["t1"]
+ self.windows.append(dict(
+ coords=coords,
+ features=features,
+ labels=labels,
+ timepoints=timepoints,
+ assoc_matrix=assoc_matrix,
+ t1=t1,
+ ))
+ self._len = len(self.windows)
+ self._get_ndim_and_nobj(None, self.windows)
+ logger.info(f"Loaded {self._len} windows from {self.window_save_path}")
+ return self.windows
+
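+ # NOTE: lru_cache on an instance method caches up to 128 recent
+ # (self, n, aug_choice) entries and keeps a reference to self alive
+ # for as long as the cache does.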
+ @lru_cache
+ def _sample_from_memory(self, n: int, aug_choice: int = 0):
+ """When self.load_from_disk is False, sample a window from memory."""
+ # logger.debug(f"Sampling window {n} with augmentation choice {aug_choice}")
+ track = self.windows[n]
+ # 0 is original, 1 to n_augs are the augmented versions
+ coords = track["coords"][aug_choice]
+ features = track["features"][aug_choice]
+ assoc_matrix = track["assoc_matrix"]
+ labels = track["labels"]
+ timepoints = track["timepoints"]
+ t1 = track["t1"]
+
+ return coords, features, labels, timepoints, assoc_matrix, t1
+
+ def _sample_from_file(self, window_id: int, aug_choice: int = 0):
+ """When self.load_from_disk is True, sample a window from the saved zarr file."""
+ root = zarr.open_group(str(self.window_save_path), mode="r")
+ grp = root[f"window_{window_id}"]
+ coords = grp[f"coords_{aug_choice}"][...]
+ features = {}
+ for k in grp[f"features_{aug_choice}"].keys():
+ features[k] = grp[f"features_{aug_choice}"][k][...]
+ labels = grp["labels"][...]
+ timepoints = grp["timepoints"][...]
+ assoc_matrix = grp["assoc_matrix"][...]
+ t1 = grp.attrs["t1"]
+ return coords, features, labels, timepoints, assoc_matrix, t1
+
+ def _augment_item(self, item: wrfeat.WRAugPretrainedFeatures, labels, timepoints, assoc_matrix):
+ """Apply augmentations to the features."""
+ # FIXME some arguments are redundant
+ if self.cropper is not None:
+ # Use only if there is at least one timepoint per detection
+ cropped_item, cropped_idx = self.cropper(item)
+ cropped_timepoints = item.timepoints[cropped_idx]
+ if len(np.unique(cropped_timepoints)) == self.window_size:
+ idx = cropped_idx
+ item = cropped_item
+ labels = labels[idx]
+ timepoints = timepoints[idx]
+ assoc_matrix = assoc_matrix[idx][:, idx]
+ # else:
+ # logger.debug("Skipping cropping")
+
+ if self.augmenter is not None:
+ item = self.augmenter(item)
+
+ return item, assoc_matrix
+
+ @lru_cache
+ def get_augmented_image_shape(self, aug_choice: int):
+ try:
+ image_shape = self.augmented_image_shapes[aug_choice]
+ except KeyError:
+ root = zarr.open_group(str(self._aug_embeds_file), mode="r")
+ metadata_json = root[str(aug_choice)].attrs["metadata"]
+ metadata = json.loads(metadata_json)
+ image_shape = metadata["image_shape"]
+ return image_shape
+
+ def __len__(self):
+ if self.save_windows and self.windows is None:
+ return self._len
+ else:
+ return len(self.windows)
+
+ def __getitem__(self, n: int, return_dense=None):
+ if return_dense is None:
+ return_dense = self.return_dense
+
+ random_aug_choice = self._rng.integers(0, self.n_augs + 1)
+
+ if self.load_from_disk:
+ coords, features, labels, timepoints, assoc_matrix, _ = self._sample_from_file(
+ n, random_aug_choice
+ )
+ else:
+ coords, features, labels, timepoints, assoc_matrix, _ = self._sample_from_memory(
+ n, random_aug_choice
+ )
+
+ # if return_dense and isinstance(mask, _CompressedArray):
+ # mask = CTCDataAugPretrainedFeats.decompress(mask)
+ # if return_dense and isinstance(img, _CompressedArray):
+ # img = CTCDataAugPretrainedFeats.decompress(img)
+ if isinstance(assoc_matrix, _CompressedArray):
+ assoc_matrix = CTCDataAugPretrainedFeats.decompress(assoc_matrix)
+
+ coords = np.stack(coords, axis=0)
+ # features = np.stack(features, axis=0)
+
+ augment_wrfeat = wrfeat.WRAugPretrainedFeatures.from_window(
+ features=features,
+ coords=coords,
+ timepoints=timepoints,
+ labels=labels,
+ )
+ augmented_data, assoc_matrix = self._augment_item(augment_wrfeat, labels, timepoints, assoc_matrix)
+ if not isinstance(augmented_data, wrfeat.WRAugPretrainedFeatures):
+ raise ValueError("Augmented data is not a WRAugPretrainedFeatures. Check that augmenter return type is correct.")
+ features, pretrained_features, coords, timepoints, labels = augmented_data.to_window()
+
+ shapes = [
+ len(labels),
+ len(timepoints),
+ len(coords),
+ len(pretrained_features),
+ len(assoc_matrix),
+ ]
+ if features is not None:
+ shapes.append(len(features))
+ if len(np.unique(shapes)) != 1:
+ raise ValueError(f"Shape mismatch: {shapes} (labs/timepoints/coords/features)")
+
+ if coords.shape[-1] != self.ndim + 1:
+ raise ValueError(f"Coords shape mismatch: {coords.shape[-1]} != {self.ndim + 1}")
+
+ # coords is already including time, simply remove min_time along the first axis
+ # coords[:, 0] -= min_time
+
+ if self.max_tokens and len(timepoints) > self.max_tokens:
+ time_incs = np.where(timepoints - np.roll(timepoints, 1))[0]
+ n_elems = time_incs[np.searchsorted(time_incs, self.max_tokens) - 1]
+ timepoints = timepoints[:n_elems]
+ labels = labels[:n_elems]
+ coords = coords[:n_elems]
+ if features is not None:
+ features = features[:n_elems]
+ if pretrained_features is not None:
+ pretrained_features = pretrained_features[:n_elems]
+ assoc_matrix = assoc_matrix[:n_elems, :n_elems]
+ logger.info(
+ f"Clipped window of size {timepoints[n_elems - 1] - timepoints.min()}"
+ )
+
+ coords0 = torch.from_numpy(coords).float()
+ if features is not None:
+ features = torch.from_numpy(features).float()
+ pretrained_features = torch.from_numpy(pretrained_features).float()
+ assoc_matrix = torch.from_numpy(assoc_matrix.copy()).float()
+ labels = torch.from_numpy(labels).long()
+ timepoints = torch.from_numpy(timepoints).long()
+
+ if self.augmenter is not None:
+ coords = coords0.clone()
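+ # apply one random global offset to all spatial coordinates (same shift for every object in the window)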
+ coords[:, 1:] += torch.randint(0, 512, (1, self.ndim))
+ else:
+ coords = coords0.clone()
+
+ if self.rotate_feats:
+ image_shape = self.get_augmented_image_shape(random_aug_choice)
+ pretrained_features = CTCData.rotate_features(
+ pretrained_features, coords, image_shape,
+ n_rot_dims=self.pretrained_feat_dim # // 2
+ )
+
+ res = dict(
+ features=features,
+ pretrained_features=pretrained_features,
+ coords0=coords0,
+ coords=coords,
+ assoc_matrix=assoc_matrix,
+ timepoints=timepoints,
+ labels=labels,
+ )
+
+ # if return_dense:
+ # if all([x is not None for x in img]):
+ # img = torch.from_numpy(img).float()
+ # res["img"] = img
+
+ # mask = torch.from_numpy(mask.astype(int)).long()
+ # res["mask"] = mask
return res
@@ -1354,6 +2396,15 @@ def _assoc(A: np.ndarray, labels: np.ndarray, family: np.ndarray):
A[i, j] = family[i, labels[j]]
+def determine_ctc_class(dataset_kwargs: dict) -> type[CTCData]:
+ if "features" not in dataset_kwargs:
+ raise ValueError("features must be set in dataset_kwargs")
+ if dataset_kwargs["features"] == "pretrained_feats_aug":
+ return CTCDataAugPretrainedFeats
+ else:
+ return CTCData
+
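+# Example (sketch): selecting the dataset class from training kwargs.
+# determine_ctc_class({"features": "pretrained_feats_aug"}) # -> CTCDataAugPretrainedFeats
+# determine_ctc_class({"features": "wrfeat"}) # -> CTCData (any other value)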
+
def _ctc_assoc_matrix(detections, ts, graph, matching):
"""Create the association matrix for a list of labels and a tracklet parent -> childrend graph.
@@ -1471,9 +2522,16 @@ def collate_sequence_padding(batch):
normal_keys = {
"coords": 0,
"features": 0,
+ "pretrained_features": 0,
"labels": 0, # Not needed, remove for speed.
"timepoints": -1, # There are real timepoints with t=0. -1 for distinction from that.
}
+ actual_keys = {
+ k: v for k, v in normal_keys.items() if k in batch[0] and batch[0][k] is not None
+ }
+ none_keys = [
+ k for k in normal_keys.keys() if k in batch[0] and batch[0][k] is None
+ ]
n_pads = tuple(n_max_len - s for s in lens)
batch_new = dict(
(
@@ -1482,8 +2540,10 @@ def collate_sequence_padding(batch):
[pad_tensor(x[k], n_max=n_max_len, value=v) for x in batch], dim=0
),
)
- for k, v in normal_keys.items()
+ for k, v in actual_keys.items()
)
+ for k in none_keys:
+ batch_new[k] = None
batch_new["assoc_matrix"] = torch.stack(
[
pad_tensor(
@@ -1500,6 +2560,8 @@ def collate_sequence_padding(batch):
pad_mask[i, n_max_len - n_pad :] = True
batch_new["padding_mask"] = pad_mask.bool()
+ if torch.all(pad_mask.bool()):
+ raise ValueError("All entries in the batch are padding; no valid tokens found.")
return batch_new
diff --git a/trackastra/data/distributed.py b/trackastra/data/distributed.py
index 04d0082..c0da0b0 100644
--- a/trackastra/data/distributed.py
+++ b/trackastra/data/distributed.py
@@ -20,26 +20,18 @@
DistributedSampler,
)
-from .data import CTCData
+from .data import CTCData, CTCDataAugPretrainedFeats, determine_ctc_class
+from .utils import make_hashable
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
-def cache_class(cachedir=None):
+def cache_class(dataset_kwargs, cachedir=None):
"""A simple file cache for CTCData."""
- def make_hashable(obj):
- if isinstance(obj, tuple | list):
- return tuple(make_hashable(e) for e in obj)
- elif isinstance(obj, Path):
- return obj.as_posix()
- elif isinstance(obj, dict):
- return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
- else:
- return obj
-
def hash_args_kwargs(*args, **kwargs):
+ # FIXME rotate_features arg should not be part of hash, if it changes cached data does not change
hashable_args = tuple(make_hashable(arg) for arg in args)
hashable_kwargs = make_hashable(kwargs)
combined_serialized = json.dumps(
@@ -49,7 +41,7 @@ def hash_args_kwargs(*args, **kwargs):
return hash_obj.hexdigest()
if cachedir is None:
- return CTCData
+ return determine_ctc_class(dataset_kwargs)
else:
cachedir = Path(cachedir)
@@ -60,11 +52,25 @@ def _wrapped(*args, **kwargs):
if cache_file.exists():
logger.info(f"Loading cached dataset from {cache_file}")
with open(cache_file, "rb") as f:
- return pickle.load(f)
+ c = pickle.load(f)
+ # if c.pretrained_config is not None:
+ # cfg = c.pretrained_config
+ # if cfg.pca_preprocessor_path is not None:
+ # pca = EmbeddingsPCACompression.from_pretrained_cfg(cfg)
+ # pca.load_from_file(cfg.pca_preprocessor_path)
+ # c.pca_preprocessor = pca
+
+ return c
else:
- c = CTCData(*args, **kwargs)
+ c = determine_ctc_class(dataset_kwargs)(*args, **kwargs)
+ if c.pretrained_config is not None:
+ c.pretrained_config = c.pretrained_config.to_dict()
+ c.feature_extractor = None
+ if isinstance(c, CTCDataAugPretrainedFeats):
+ c.augmented_feature_extractor = None
logger.info(f"Saving cached dataset to {cache_file}")
- pickle.dump(c, open(cache_file, "wb"))
+ with open(cache_file, "wb") as f:
+ pickle.dump(c, f)
+ logger.debug(f"Cache file size: {cache_file.stat().st_size / 1e6:.2f} MB")
return c
return _wrapped
@@ -156,6 +162,7 @@ def sample_batches(self, idx: Iterable[int]):
# continue
batch = idx_pool[j : j + self.batch_size]
batches.append(batch)
+
return batches
def __iter__(self):
@@ -212,7 +219,7 @@ def __init__(
input_val: list,
cachedir: str,
augment: int,
- distributed:bool,
+ distributed: bool,
dataset_kwargs: dict,
sampler_kwargs: dict,
loader_kwargs: dict,
@@ -232,43 +239,83 @@ def prepare_data(self):
Running on the main CPU process.
"""
- CTCData = cache_class(self.cachedir)
+ CachedData = cache_class(
+ dataset_kwargs=self.dataset_kwargs,
+ cachedir=self.cachedir
+ )
datasets = dict()
+
for split, inps in zip(
("train", "val"),
(self.input_train, self.input_val),
):
logger.info(f"Loading {split.upper()} data")
start = default_timer()
- datasets[split] = torch.utils.data.ConcatDataset(
- CTCData(
+ local_kwargs = deepcopy(self.dataset_kwargs)
+ if self.dataset_kwargs.get("features") == "pretrained_feats_aug" and split == "val":
+ # do not compute augmented pretrained features for the val set
+ local_kwargs["features"] = "pretrained_feats"
+
+ ctc_datasets = [
+ CachedData(
root=Path(inp),
augment=self.augment if split == "train" else 0,
- **self.dataset_kwargs,
+ **local_kwargs,
)
for inp in inps
+ ]
+ datasets[split] = torch.utils.data.ConcatDataset(
+ ctc_datasets
)
+ del ctc_datasets
logger.info(
f"Loaded {len(datasets[split])} {split.upper()} samples (in"
f" {(default_timer() - start):.1f} s)\n\n"
)
+ # if self.dataset_kwargs.get("pretrained_backbone_config") is not None and split == "train":
+ # cfg = self.dataset_kwargs["pretrained_backbone_config"]
+ # if cfg.pca_preprocessor_path is not None:
+ # pca = EmbeddingsPCACompression.from_pretrained_cfg(cfg)
+ # embeddings_paths = []
+ # for p in feature_extractor_save_paths:
+ # embeddings_paths.append(p)
+ # pca.fit_on_embeddings(embeddings_paths)
+
del datasets
def setup(self, stage: str):
- CTCData = cache_class(self.cachedir)
+ CachedData = cache_class(
+ dataset_kwargs=self.dataset_kwargs,
+ cachedir=self.cachedir
+ )
self.datasets = dict()
+
+ # if self.dataset_kwargs.get("pretrained_backbone_config") is not None:
+ # cfg = self.dataset_kwargs["pretrained_backbone_config"]
+ # if cfg.pca_preprocessor_path is not None:
+ # pca = EmbeddingsPCACompression.from_pretrained_cfg(cfg)
+ # pca.load_from_file(cfg.pca_preprocessor_path)
+ # self.dataset_kwargs["pca_preprocessor"] = pca
+
for split, inps in zip(
("train", "val"),
(self.input_train, self.input_val),
):
logger.info(f"Loading {split.upper()} data")
start = default_timer()
+ local_kwargs = deepcopy(self.dataset_kwargs)
+ if self.dataset_kwargs.get("features") == "pretrained_feats_aug" and split == "val":
+ # do not compute augmented pretrained features for the val set
+ local_kwargs["features"] = "pretrained_feats"
self.datasets[split] = torch.utils.data.ConcatDataset(
- CTCData(
+ CachedData(
root=Path(inp),
augment=self.augment if split == "train" else 0,
- **self.dataset_kwargs,
+ **local_kwargs,
)
for inp in inps
)
@@ -286,7 +333,7 @@ def train_dataloader(self):
)
batch_sampler = None
else:
- sampler=None
+ sampler = None
batch_sampler = BalancedBatchSampler(
self.datasets["train"],
**self.sampler_kwargs,
diff --git a/trackastra/data/features.py b/trackastra/data/features.py
index 0c02f78..b4a489e 100644
--- a/trackastra/data/features.py
+++ b/trackastra/data/features.py
@@ -1,9 +1,13 @@
import itertools
+import logging
import numpy as np
import pandas as pd
from skimage.measure import regionprops_table
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
# the property keys that are supported for 2 and 3 dim
_PROPERTIES = {
diff --git a/trackastra/data/pretrained_augmentations.py b/trackastra/data/pretrained_augmentations.py
new file mode 100644
index 0000000..f56a0f5
--- /dev/null
+++ b/trackastra/data/pretrained_augmentations.py
@@ -0,0 +1,418 @@
+from abc import ABC, abstractmethod
+from collections import OrderedDict
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torchvision import tv_tensors
+from torchvision.transforms import v2 as transforms
+
+from trackastra.utils.utils import percentile_norm
+
+
+class BaseAugmentation(ABC):
+ """Base class for windowed region augmentations."""
+ def __init__(self, p: float = 0.5, rng_seed=None):
+ self._p = p
+ self._rng = np.random.RandomState(rng_seed)
+ self.applied_record = {}
+ self.signature = {}
+
+ def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask):
+ if self._p is None or self._rng.rand() < self._p:
+ aug = self._get_aug()
+ return aug(images, masks)
+ return images, masks
+
+ @abstractmethod
+ def _get_aug(self) -> transforms.Compose:
+ raise NotImplementedError()
+
+
+class FlipAugment(BaseAugmentation):
+ def __init__(self, p_horizontal: float = 0.5, p_vertical: float = 0.5, rng_seed=None):
+ super().__init__(p=None, rng_seed=rng_seed)
+ self._p_horizontal = p_horizontal
+ self._p_vertical = p_vertical
+ self.signature = {
+ "FlipAugment": {
+ "horizontal": self._p_horizontal,
+ "vertical": self._p_vertical
+ }
+ }
+
+ def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask):
+ if self._rng.rand() < self._p_horizontal:
+ images = transforms.functional.hflip(images)
+ masks = transforms.functional.hflip(masks)
+ self.applied_record["hflip"] = True
+ else:
+ self.applied_record["hflip"] = False
+ if self._rng.rand() < self._p_vertical:
+ images = transforms.functional.vflip(images)
+ masks = transforms.functional.vflip(masks)
+ self.applied_record["vflip"] = True
+ else:
+ self.applied_record["vflip"] = False
+ return images, masks
+
+ def _get_aug(self) -> transforms.Compose:
+ raise NotImplementedError("Use __call__ instead.")
+
+
+class RotAugment(BaseAugmentation):
+
+ def __init__(self, p: float = 0.5, degrees: int = 15, rng_seed=None):
+ super().__init__(p, rng_seed=rng_seed)
+ self.degrees = degrees
+ self.signature = {
+ "RotAugment": {
+ "p": self._p,
+ "degrees": self.degrees,
+ }
+ }
+
+ def _get_aug(self):
+ self.applied_record["rotation"] = self.degrees
+ t = transforms.RandomRotation(degrees=self.degrees)
+ return t
+
+
+class Rot90Augment(BaseAugmentation):
+
+ def __init__(self, p=0.5, rng_seed=None):
+ super().__init__(p, rng_seed=rng_seed)
+ self.signature = {
+ "Rot90Augment": {
+ "p": self._p,
+ }
+ }
+
+ def __call__(self, images, masks):
+ if self._rng.rand() > self._p:
+ return images, masks
+ angle = self._get_aug()
+ images = transforms.functional.rotate(images, angle, expand=True)
+ masks = transforms.functional.rotate(masks, angle, expand=True)
+ return images, masks
+
+ def _get_aug(self):
+ angle = self._rng.choice([90, 180, 270])
+ self.applied_record["rot90"] = int(angle)
+ return angle
+
+
+class BrightnessJitter(BaseAugmentation):
+
+ def __init__(self, bright_shift: float = 0.5, contrast_shift: float = 0.5, rng_seed=None):
+ super().__init__(p=None, rng_seed=rng_seed)
+ self._b_shift = bright_shift
+ self._c_shift = contrast_shift
+ self.signature = {
+ "BrightnessJitter": {
+ "brightness_shift": self._b_shift,
+ "contrast_shift": self._c_shift
+ }
+ }
+
+ def _get_aug(self):
+ if self._b_shift is not None:
+ bright = self._rng.uniform(0, self._b_shift)
+ else:
+ bright = None
+ self.applied_record["brightness_jitter"] = bright
+ if self._c_shift is not None:
+ contrast = self._rng.uniform(0, self._c_shift)
+ else:
+ contrast = None
+ self.applied_record["contrast_jitter"] = contrast
+ return transforms.ColorJitter(brightness=bright, contrast=contrast)
+
+
+class AddGaussianNoise(BaseAugmentation):
+ def __init__(self, mean: float = 0.0, std: float = 0.1, rng_seed=None):
+ super().__init__(p=None, rng_seed=rng_seed)
+ self.mean = mean
+ self.sigma = std
+ self.signature = {
+ "AddGaussianNoise": {
+ "mean": self.mean,
+ "std": self.sigma
+ }
+ }
+
+ def _get_aug(self):
+ # sample random mean/std
+ self.applied_record["gaussian_noise"] = (self.mean, self.sigma)
+ return transforms.GaussianNoise(mean=self.mean, sigma=self.sigma)
+
+ def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask):
+ aug = self._get_aug()
+ images = aug(images)
+ return images, masks
+
+
+class GaussianBlur(BaseAugmentation):
+ def __init__(self, kernel_size: int = 3, sigma: tuple[float, float] = (0.01, 1.0), rng_seed=None):
+ super().__init__(p=None, rng_seed=rng_seed)
+ self.kernel_size = kernel_size
+ self.sigma = sigma
+ self.signature = {
+ "GaussianBlur": {
+ "kernel_size": self.kernel_size,
+ "sigma": self.sigma
+ }
+ }
+
+ def _get_aug(self):
+ self.applied_record["gaussian_blur"] = (self.kernel_size, self.sigma)
+ return transforms.GaussianBlur(kernel_size=self.kernel_size, sigma=self.sigma)
+
+ def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask):
+ aug = self._get_aug()
+ images = aug(images)
+ return images, masks
+
+
+class RandomAffine(BaseAugmentation):
+ def __init__(self, degrees: float = 0.0, translate: tuple[float, float] = (0.0, 0.0), scale: tuple[float, float] = (1.0, 1.0), rng_seed=None):
+ super().__init__(p=None, rng_seed=rng_seed)
+ self.degrees = degrees
+ self.translate = translate
+ self.scale = scale
+ self.signature = {
+ "RandomAffine": {
+ "degrees": self.degrees,
+ "translate": self.translate,
+ "scale": self.scale
+ }
+ }
+
+ def _get_aug(self):
+ return transforms.RandomAffine(degrees=self.degrees, translate=self.translate, scale=self.scale)
+
+ def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask):
+ aug = self._get_aug()
+ images, masks = aug(images, masks)
+ return images, masks
+
+
+class ElasticTransform(BaseAugmentation):
+ def __init__(self, p=0.5, alpha: float = 10.0, sigma: float = 0.5, rng_seed=None):
+ super().__init__(p=p, rng_seed=rng_seed)
+ self.alpha = alpha
+ self.sigma = sigma
+ self.signature = {
+ "ElasticTransform": {
+ "p": self._p,
+ "alpha": self.alpha,
+ "sigma": self.sigma
+ }
+ }
+
+ def _get_aug(self):
+ alpha = self._rng.uniform(0, self.alpha)
+ sigma = self._rng.uniform(0, self.sigma)
+ self.applied_record["elastic_transform"] = (alpha, sigma)
+ return transforms.ElasticTransform(alpha=alpha, sigma=sigma)
+
+
+class RandomScale(BaseAugmentation):
+ def __init__(self, p: float = 0.9, max_scale: float = 1.0, min_scale: float = 0.8, preserve_size: bool = False, rng_seed=None):
+ super().__init__(p=p, rng_seed=rng_seed)
+ self.min_scale = min_scale
+ self.max_scale = max_scale
+ self.preserve_size = preserve_size
+ self.signature = {
+ "RandomScale": {
+ "p": self._p,
+ "min_scale": self.min_scale,
+ "max_scale": self.max_scale,
+ "preserve_size": self.preserve_size
+ }
+ }
+
+ def _get_aug(self):
+ scale = self._rng.uniform(self.min_scale, self.max_scale)
+ self.applied_record["random_scale"] = scale
+ return scale
+
+ def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask):
+ if self._p is None or self._rng.rand() < self._p:
+ scale = self._get_aug()
+ orig_h, orig_w = images.shape[-2], images.shape[-1]
+ new_h, new_w = int(orig_h * scale), int(orig_w * scale)
+
+ # Resize images and masks
+ images_scaled = F.interpolate(images, size=(new_h, new_w), mode="bilinear", align_corners=False)
+ masks_scaled = F.interpolate(masks.float(), size=(new_h, new_w), mode="nearest").long()
+
+ if self.preserve_size:
+ pad_h = max(orig_h - new_h, 0)
+ pad_w = max(orig_w - new_w, 0)
+ pad = [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] # left, right, top, bottom
+
+ images_scaled = F.pad(images_scaled, pad, mode="constant", value=0)
+ masks_scaled = F.pad(masks_scaled, pad, mode="constant", value=0)
+
+ # If scaled image is larger, crop to original size
+ images_scaled = images_scaled[..., :orig_h, :orig_w]
+ masks_scaled = masks_scaled[..., :orig_h, :orig_w]
+
+ return images_scaled, masks_scaled
+ return images, masks
+
+
+class IdentityAugment(BaseAugmentation):
+ """Identity augmentation for debugging purposes."""
+ def __init__(self, p: float = 1.0, rng_seed=None):
+ super().__init__(p=p, rng_seed=rng_seed)
+ self.signature = {
+ "IdentityAugment": {
+ "p": self._p
+ }
+ }
+
+ def _get_aug(self):
+ self.applied_record["identity"] = True
+ return transforms.Lambda(lambda x: x)
+
+ def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask):
+ return images, masks
+
+
+class PretrainedAugmentations:
+ """Augmentation pipeline to get augmented copies of model embeddings."""
+ default_normalize = percentile_norm
+
+ def __init__(self, rng_seed=None, normalize=True, shuffle=True):
+ self.aug_record = {}
+ self.aug_list = [
+ # IdentityAugment(rng_seed=rng_seed), # debugging
+ BrightnessJitter(bright_shift=0.3, contrast_shift=0.3, rng_seed=rng_seed),
+ FlipAugment(p_horizontal=0.5, p_vertical=0.5, rng_seed=rng_seed),
+ # RotAugment(degrees=10, rng_seed=rng_seed),
+ Rot90Augment(p=0.5, rng_seed=rng_seed),
+ # AddGaussianNoise(mean=0.0, std=0.02, rng_seed=rng_seed),
+ RandomScale(rng_seed=rng_seed),
+ GaussianBlur(kernel_size=5, sigma=(0.01, 2.0), rng_seed=rng_seed),
+ # ElasticTransform(p=0.25, alpha=10.0, sigma=0.5, rng_seed=rng_seed),
+ # RandomAffine(degrees=0.0, translate=(0.1, 0.1), scale=(0.9, 1.1), rng_seed=rng_seed),
+ ]
+ self._aug = None
+ self._rng_seed = rng_seed
+ self._rng = np.random.RandomState(rng_seed)
+ self.normalize = normalize
+ self.image_shape = None
+ self.shuffle = shuffle
+
+ def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask, normalize_func=None) -> tuple[torch.Tensor, tv_tensors.Mask, dict]:
+ """Applies the augmentations to the images."""
+ images, masks = self.preprocess(images, masks, normalize_func=normalize_func)
+
+ if self.shuffle:
+ aug_list = self.aug_list.copy()
+ self._rng.shuffle(aug_list)
+ else:
+ aug_list = self.aug_list
+ self._aug = transforms.Compose(aug_list)
+
+ images = torch.unsqueeze(images, dim=1) # add channel dimension (T, C, H, W) for augmentation
+ masks = torch.unsqueeze(masks, dim=1) # add channel dimension (T, C, H, W) for augmentation
+
+ images, masks = self._aug(images, masks)
+ if torch.isnan(images).any() or torch.isnan(masks).any():
+ raise RuntimeError("NaN values found in images or masks after augmentation.")
+ self.image_shape = images.shape
+ # NOTE: most models require 3 channels; this is handled in FeatureExtractor, so the output is squeezed here
+ return images.squeeze(), masks.squeeze(), self.gather_records()
+
+ def get_signature(self):
+ """Returns the signature of the augmentations."""
+ if self._aug is None:
+ self._aug = transforms.Compose(self.aug_list)
+ signatures = OrderedDict()
+ for aug in self.aug_list:
+ if aug.signature:
+ signatures.update(aug.signature)
+ return signatures
+
+ def __add__(self, other):
+ """Combines two augmentation pipelines."""
+ if not isinstance(other, PretrainedAugmentations):
+ raise TypeError("Can only combine with another PretrainedAugmentations instance.")
+ combined = PretrainedAugmentations(rng_seed=self._rng_seed, normalize=self.normalize, shuffle=self.shuffle)
+ combined.aug_list = self.aug_list + other.aug_list
+ return combined
+
+ def __repr__(self):
+ sig = self.get_signature()
+ msg = "Augmentation pipeline"
+ for aug_name, params in sig.items():
+ msg += f"- {aug_name}: {params}\n"
+ msg += "_" * 40 + "\n"
+ return msg.strip()
+
+ def preprocess(self, images, masks, normalize_func=None):
+ if len(images.shape) != 3:
+ raise ValueError(f"Images must be a tensor of shape (T, H, W), got a {len(images.shape)}D input.")
+ if len(masks.shape) != 3:
+ raise ValueError(f"Masks must be a tensor of shape (T, H, W), got a {len(masks.shape)}D input.")
+
+ if not isinstance(images, torch.Tensor):
+ try:
+ images = torch.tensor(images, dtype=torch.float32)
+ except Exception as e:
+ raise ValueError(f"Failed to convert images to tensor: {e}") from e
+ if not isinstance(masks, tv_tensors.Mask):
+ try:
+ masks = tv_tensors.Mask(masks, dtype=torch.int64)
+ except Exception as e:
+ raise ValueError(f"Failed to convert masks to a Mask tensor: {e}") from e
+
+ if normalize_func is not None:
+ if not callable(normalize_func):
+ raise ValueError("normalize_func must be a callable function.")
+ images = normalize_func(images)
+
+ return images, masks
+
+ def gather_records(self):
+ """Gathers the applied augmentation records."""
+ self.aug_record = {}
+ for aug in self.aug_list:
+ self.aug_record.update(aug.applied_record)
+ self.aug_record["image_shape"] = self.image_shape
+ return self.aug_record
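+
+# Usage sketch (illustrative shapes; (T, H, W) float images, integer masks):
+#
+# pipeline = PretrainedAugmentations(rng_seed=42)
+# images, masks, record = pipeline(
+# torch.rand(5, 256, 256),
+# tv_tensors.Mask(torch.zeros(5, 256, 256, dtype=torch.int64)),
+# )
+# record["image_shape"] # (T, C, H', W') after augmentation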
+
+
+class PretrainedMovementAugmentations(PretrainedAugmentations):
+ """Augmentation pipeline for movement embeddings."""
+ def __init__(self, rng_seed=None, normalize=True, shuffle=True):
+ super().__init__(rng_seed=rng_seed, normalize=normalize, shuffle=shuffle)
+ self.aug_list = [
+ FlipAugment(p_horizontal=0.5, p_vertical=0.5, rng_seed=rng_seed),
+ Rot90Augment(p=0.5, rng_seed=rng_seed),
+ RandomScale(rng_seed=rng_seed),
+ ]
+
+
+class PretrainedIntensityAugmentations(PretrainedAugmentations):
+ """Augmentation pipeline for intensity embeddings."""
+ def __init__(self, rng_seed=None, normalize=True, shuffle=True):
+ super().__init__(rng_seed=rng_seed, normalize=normalize, shuffle=shuffle)
+ self.aug_list = [
+ BrightnessJitter(bright_shift=0.05, contrast_shift=0.05, rng_seed=rng_seed),
+ # FlipAugment(p_horizontal=0.5, p_vertical=0.5, rng_seed=rng_seed),
+ # Rot90Augment(p=0.5, rng_seed=rng_seed),
+ AddGaussianNoise(mean=0.0, std=0.02, rng_seed=rng_seed),
+ # RandomScale(rng_seed=rng_seed),
+ ]
+
+
+class IdentityAugmentations(PretrainedAugmentations):
+ """Identity augmentation pipeline for debugging."""
+ def __init__(self, rng_seed=None, normalize=True, shuffle=True):
+ super().__init__(rng_seed=rng_seed, normalize=normalize, shuffle=shuffle)
+ self.aug_list = [
+ IdentityAugment(p=1.0, rng_seed=rng_seed),
+ ]
\ No newline at end of file
diff --git a/trackastra/data/pretrained_features.py b/trackastra/data/pretrained_features.py
new file mode 100644
index 0000000..0165967
--- /dev/null
+++ b/trackastra/data/pretrained_features.py
@@ -0,0 +1,2122 @@
+import json
+import logging
+import os
+import time
+from abc import ABC, abstractmethod
+from dataclasses import asdict, dataclass
+from functools import partial, wraps
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal
+
+import joblib
+import numpy as np
+import torch
+import torch.nn.functional as F
+import zarr
+from numcodecs import Blosc
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from skimage.measure import regionprops
+from tqdm import tqdm
+from transformers import (
+ AutoImageProcessor,
+ # Dinov2Config,
+ # Dinov2Model,
+ AutoModel,
+ HieraConfig,
+ HieraModel,
+ SamModel,
+ SamProcessor,
+)
+
+from trackastra.data import wrfeat
+from trackastra.utils.utils import percentile_norm
+
+if TYPE_CHECKING:
+ from trackastra.data.pretrained_augmentations import PretrainedAugmentations
+
+try:
+ from micro_sam.util import get_sam_model as get_microsam_model
+ MICRO_SAM_AVAILABLE = True
+except ImportError:
+ MICRO_SAM_AVAILABLE = False
+
+try:
+ from tarrow.models import TimeArrowNet
+ from tarrow.utils import normalize as tap_normalize
+ TARROW_AVAILABLE = True
+except ImportError:
+ TARROW_AVAILABLE = False
+
+try:
+ from cellpose import transforms as cp_transforms
+ from cellpose.vit_sam import Transformer as CellposeSAM
+ CELLPOSE_AVAILABLE = True
+except ImportError:
+ CELLPOSE_AVAILABLE = False
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+# Updated with actual class after each definition
+# See register_backbone decorator
+AVAILABLE_PRETRAINED_BACKBONES = {}
+
+PretrainedFeatsExtractionMode = Literal[
+ # "exact_patch", # Uses the image patch centered on the detection for embedding
+ "nearest_patch", # Runs on whole image, then finds the nearest patch to the detection in the embedding
+ "mean_patches_bbox", # Runs on whole image, then averages the embeddings of all patches that intersect with the detection's bounding box
+ "mean_patches_exact", # Runs on whole image, then averages the embeddings of all patches that intersect with the detection
+ "max_patches_bbox", # Runs on whole image, then takes the maximum for each feature dimension of all patches that intersect with the detection
+ "max_patches_exact", # Runs on whole image, then takes the maximum for each feature dimension of all patches that intersect with the detection
+ "median_patches_exact", # Runs on whole image, then takes the median for each feature dimension of all patches that intersect with the detection
+]
+
+PretrainedBackboneType = Literal[ # cannot unpack this directly in python < 3.11 so it has to be copied
+ "facebook/hiera-tiny-224-hf", # 768
+ "facebook/dinov2-base", # 768
+ "facebook/sam-vit-base", # 256
+ "facebook/sam2-hiera-large", # 256
+ "facebook/sam2.1-hiera-base-plus", # 256
+ "facebookresearch/co-tracker", # 128
+ "microsam/vit_b_lm",
+ "microsam/vit_l_lm",
+ "weigertlab/tarrow", # arbitrary. default 32
+ "mouseland/cellpose-sam", # 192
+ "facebook/sam2.1-hiera-base-plus/highres",
+ "debug/random",
+ "debug/encoded_labels", # 32
+]
+
+
+def register_backbone(model_name, feat_dim):
+ def decorator(cls):
+ AVAILABLE_PRETRAINED_BACKBONES[model_name] = {
+ "class": cls,
+ "feat_dim": feat_dim,
+ }
+ return cls
+ return decorator
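+
+# Usage sketch (hypothetical backbone name):
+#
+# @register_backbone("acme/toy-encoder", feat_dim=128)
+# class ToyFeatures(FeatureExtractor):
+# ...
+#
+# AVAILABLE_PRETRAINED_BACKBONES["acme/toy-encoder"]["feat_dim"] # -> 128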
+
+
+# Feature extraction from pretrained models.
+# Meant to wrap any transformers model.
+# NOTE: currently not applicable to 3D data
+# (but aggregation-based modes may be adapted eventually).
+
+
+def average_time_decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if not hasattr(wrapper, "total_time"):
+ wrapper.total_time = 0
+ wrapper.call_count = 0
+
+ start_time = time.time()
+ result = func(*args, **kwargs)
+ elapsed_time = time.time() - start_time
+
+ wrapper.total_time += elapsed_time
+ wrapper.call_count += 1
+ average_time = wrapper.total_time / wrapper.call_count
+
+ print(f"Average time taken by {func.__name__}: {average_time:.6f} seconds over {wrapper.call_count} calls")
+
+ return result
+
+ return wrapper
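+
+# Usage sketch:
+#
+# @average_time_decorator
+# def embed_batch(batch):
+# ...
+#
+# Each call then logs the running average runtime of embed_batch.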
+
+
+# Configs for pretrained models ###
+
+@dataclass
+class PretrainedFeatureExtractorConfig:
+ """model_name (str):
+ Specify the pretrained backbone to use.
+ model_path (str | Path):
+ Path to the pretrained model.
+ save_path (str | Path):
+ Specify the path to save the embeddings.
+ batch_size (int):
+ Specify the batch size to use for the model.
+ mode (str):
+ Specify the mode to use for the model.
+ Currently available modes are "nearest_patch", "mean_patches_bbox", "mean_patches_exact", "max_patches_bbox", "max_patches_exact".
+ normalize_embeddings (bool):
+ Whether to normalize the embeddings (divide by the norm).
+ device (str):
+ Specify the device to use for the model.
+ If not set and "pretrained_feats" is used, the device is automatically set by default to "cuda", "mps" or "cpu" as available.
+ n_augmented_copies (int):
+ How many augmented copies of the embeddings to create. If 0, only the original embeddings are saved. Creates n+1 embeddings entries total.
+ additional_features (str):
+ Specify any additional features (from regionprops) to include in the extraction process. See WRFeat documentation for available features. Unused if None.
+ pca_components (int):
+ Specify the number of PCA components to use for dimensionality reduction of the features. Unused if None.
+ pca_preprocessor_path (str | Path):
+ Specify the path to the pickled PCA preprocessor. This is used to transform the features to a reduced PCA feature space.
+ """
+ model_name: PretrainedBackboneType
+ model_path: str | Path | None = None
+ save_path: str | Path | None = None
+ batch_size: int = 4
+ mode: PretrainedFeatsExtractionMode = "nearest_patch"
+ normalize_embeddings: bool = True # whether to normalize the embeddings (divide by the norm)
+ device: str | None = None
+ feat_dim: int | None = None
+ additional_features: str | None = None # for regionprops features
+ additional_feat_dim: int = 0 # for regionprops features
+ n_augmented_copies: int = 0 # number of augmented copies to create
+ seed: int | None = None # seed for debug/random
+ # pca_components: int = None # for PCA reduction of the features
+ # pca_preprocessor_path: str | Path = None # for PCA preprocessor path
+ # apply_rope: bool = False # whether to apply RoPE-like rotation to the features based on coordinates
+
+ def __post_init__(self):
+ self._guess_device()
+ self.model_path = self._check_path(self.model_path)
+ self.save_path = self._check_path(self.save_path)
+ # self.pca_preprocessor_path = self._check_path(self.pca_preprocessor_path)
+ self._check_model_availability()
+
+ def _check_path(self, path):
+ if path is not None and not isinstance(path, str | Path):
+ raise ValueError(f"Path must be a string or Path object, got {type(path)}.")
+ if isinstance(path, str):
+ return Path(path).resolve()
+ return path
+
+ def _check_model_availability(self):
+ if self.model_name not in AVAILABLE_PRETRAINED_BACKBONES.keys():
+ raise ValueError(f"Model {self.model_name} is not available for feature extraction.")
+ if self.model_name == "weigertlab/tarrow":
+ if not TARROW_AVAILABLE:
+ raise ImportError("TArrow is not available. Please install it to use this model.")
+ elif self.model_path is None:
+ raise ValueError("Model path must be specified for TArrow.")
+ _, self.feat_dim = TAPFeatures._load_model_from_path(self.model_path)
+ else:
+ self.feat_dim = AVAILABLE_PRETRAINED_BACKBONES[self.model_name]["feat_dim"]
+ if self.additional_features is not None:
+ # TODO if this ever accepts 3D data this will be incorrect
+ self.additional_feat_dim = wrfeat.WRFeatures.PROPERTIES_DIMS[
+ self.additional_features
+ ][2]
+ if self.additional_features not in wrfeat._PROPERTIES:
+ raise ValueError(f"Additional feature {self.additional_features} is not valid.")
+ # if self.pca_components is not None:
+ # self.feat_dim = self.pca_components
+
+ def _guess_device(self):
+ if self.device is None:
+ should_use_mps = (
+ torch.backends.mps.is_available()
+ and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") not in (None, "0")
+ )
+ self.device = (
+ "cuda"
+ if torch.cuda.is_available()
+ else ("mps" if should_use_mps else "cpu")
+ )
+
+ try:
+ torch.device(self.device) # check if device is valid
+ except Exception as e:
+ raise ValueError(f"Invalid device: {self.device}") from e
+
+ def to_dict(self):
+ return asdict(self)
+
+ @classmethod
+ def from_dict(cls, config_dict):
+ return cls(**config_dict)
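+
+# Usage sketch (paths are placeholders):
+#
+# cfg = PretrainedFeatureExtractorConfig(
+# model_name="facebook/sam2.1-hiera-base-plus",
+# save_path="embeddings/",
+# mode="mean_patches_exact",
+# )
+# cfg.feat_dim # resolved in __post_init__ from the backbone registry (256 here)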
+
+# Feature extractors ###
+
+
+class FeatureExtractor(ABC):
+ model_name = None
+ _available_backbones = None
+
+ def __init__(
+ self,
+ image_size: tuple[int, int],
+ save_path: str | Path,
+ batch_size: int = 4,
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ mode: PretrainedFeatsExtractionMode = "nearest_patch",
+ normalize_embeddings: bool = True,
+ **kwargs,
+ ):
+ """
+ Initializes a pretrained model feature extractor with the given parameters.
+ Consumes images, computes embeddings of shape T, H*W, N (N the model feature dimension),
+ and generates n_regions x N embeddings for each object in each frame using the specified mode.
+
+ Args:
+ - image_size (tuple[int, int]): Size of the input images (height, width).
+ - save_path (str | Path): Path to save the embeddings.
+ - batch_size (int): Batch size to use for the model.
+ - device (str): Device to use for the model. Defaults to "cuda" if available, otherwise "cpu".
+ - mode (str): Mode to use for the model. Defaults to "nearest_patch". See type PretrainedFeatsExtractionMode for available modes.
+ - normalize_embeddings (bool): Whether to normalize the embeddings (divide by the norm). Defaults to True. For aggregation-based modes, this is applied before aggregation.
+ """
+ # Image processor extra args
+ # Modify as needed in subclasses
+ self.im_proc_kwargs = {
+ "do_rescale": False,
+ "do_normalize": False,
+ "do_resize": True,
+ "return_tensors": "pt",
+ "do_center_crop": False,
+ }
+ # Model specs
+ self.model = None
+ self._input_size: tuple[int] = None
+ self._final_grid_size: tuple[int] = None
+ self.n_channels: int = None
+ self.hidden_state_size: int = None
+ self.model_patch_size: int = None
+ # Data specs
+ self.orig_image_size = image_size
+ self.orig_n_channels = 1
+ # Batch options and preprocessing
+ self.do_normalize = True
+ self.rescale_batches = False
+ self.channel_first = True
+ self.batch_return_type: Literal["list[np.ndarray]", "np.ndarray", "torch.Tensor"] = "np.ndarray"
+ self.batch_size = batch_size
+ self.device = device
+ # Parameters for embedding extraction
+ self.mode = mode
+ # If running FeatureExtractor in a parallelized context, set to False to avoid overhead
+ # from spawning many threads within the parallelized context.
+ self.parallel = True
+ self.additional_features = None
+ self.apply_rope = False # deprecated, use "rotate_features" in CTCData # TODO remove
+ self.normalize_embeddings = normalize_embeddings
+ # Saving parameters
+ self.save_path: str | Path = save_path
+ self.do_save = True
+ self.force_recompute = False
+
+ self.embeddings = None
+ self._debug_view = False
+ # self._debug = True
+
+ if not isinstance(self.save_path, Path):
+ self.save_path = Path(self.save_path)
+
+ if not self.save_path.exists():
+ self.save_path.mkdir(parents=True, exist_ok=True)
+
+ @property
+ def input_size(self):
+ return self._input_size
+
+ @input_size.setter
+ def input_size(self, value: int | tuple[int]):
+ if isinstance(value, int):
+ value = (value, value)
+ elif isinstance(value, tuple):
+ if len(value) != 2:
+ raise ValueError("Input size must be a tuple of length 2.")
+ else:
+ raise ValueError("Input size must be an int or a tuple of ints.")
+ self._input_size = value
+ self._set_model_patch_size()
+
+ @property
+ def final_grid_size(self):
+ return self._final_grid_size
+
+ @final_grid_size.setter
+ def final_grid_size(self, value: int | tuple[int]):
+ if isinstance(value, int):
+ value = (value, value)
+ elif isinstance(value, tuple):
+ if len(value) != 2:
+ raise ValueError("Final grid size must be a tuple of length 2.")
+ else:
+ raise ValueError("Final grid size must be an int or a tuple of ints.")
+ self._final_grid_size = value
+ self._set_model_patch_size()
+
+ @property
+ def model_name_path(self):
+ return self.model_name.replace("/", "-")
+
+ @staticmethod
+ def _load_model_from_path(path: str | Path) -> tuple[torch.nn.Module, int]:
+ """Loads a model from the given path. Returns the model and its feature dimension (e.g. 256 for SAM2)."""
+ raise NotImplementedError("This model currently only supports being loaded from huggingface's hub.")
+
+ @classmethod
+ def from_model_name(cls,
+ model_name: PretrainedBackboneType,
+ image_shape: tuple[int, int],
+ save_path: str | Path,
+ device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
+ mode="nearest_patch",
+ additional_features=None,
+ model_folder=None,
+ ):
+ cls._available_backbones = AVAILABLE_PRETRAINED_BACKBONES
+ if model_name not in cls._available_backbones:
+ raise ValueError(f"Model {model_name} is not available for feature extraction.")
+ logger.info(f"Using model {model_name} with mode {mode} for pretrained feature extraction.")
+ backbone = cls._available_backbones[model_name]["class"]
+ backbone.model_name = model_name
+ model = backbone(
+ image_size=image_shape,
+ save_path=save_path,
+ device=device,
+ mode=mode,
+ model_folder=model_folder,
+ )
+ model.additional_features = additional_features
+ return model
+
+ @classmethod
+ def from_config(cls, config: PretrainedFeatureExtractorConfig, image_shape: tuple[int, int], save_path: str | Path | None = None):
+ cls._available_backbones = AVAILABLE_PRETRAINED_BACKBONES
+ if config.model_name not in cls._available_backbones:
+ raise ValueError(f"Model {config.model_name} is not available for feature extraction.")
+ logger.info(f"Using model {config.model_name} with mode {config.mode} for pretrained feature extraction.")
+ backbone = cls._available_backbones[config.model_name]["class"]
+
+ parts = config.model_name.split("/")
+ if len(parts) > 2:
+ model_name = "/".join(parts[:2])
+ else:
+ model_name = config.model_name
+ backbone.model_name = model_name
+
+ model = backbone(
+ image_size=image_shape,
+ save_path=save_path if save_path is not None else config.save_path,
+ batch_size=config.batch_size,
+ device=config.device,
+ mode=config.mode,
+ additional_features=config.additional_features,
+ model_folder=config.model_path,
+ normalize_embeddings=config.normalize_embeddings,
+ seed=config.seed if hasattr(config, "seed") else None,
+ # n_augmented_copies=config.n_augmented_copies,
+ # aug_pipeline=PretrainedAugmentations() if config.n_augmented_copies > 0 else None,
+ )
+ model.additional_features = config.additional_features
+ model.normalize_embeddings = config.normalize_embeddings
+ # model.apply_rope = config.apply_rope
+ return model
+
+ def clear_model(self):
+ """Clears the model from memory."""
+ if self.model is not None:
+ del self.model
+ self.model = None
+ torch.cuda.empty_cache()
+ logger.info("Model cleared from memory.")
+ else:
+ logger.warning("No model to clear from memory.")
+
+ def _set_model_patch_size(self):
+ if self.final_grid_size is None or self.input_size is None:
+ self.model_patch_size = None
+ else:
+ if not isinstance(self.input_size, tuple):
+ raise ValueError("Input size must be a tuple of ints.")
+ self.model_patch_size = (
+ self.input_size[0] // self.final_grid_size[0],
+ self.input_size[1] // self.final_grid_size[1],
+ )
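+ # e.g. input_size=(1024, 1024) with final_grid_size=(64, 64) yields a
+ # model_patch_size of (16, 16), i.e. each grid cell covers 16x16 input pixels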
+ if self.model_patch_size[0] <= 0 or self.model_patch_size[1] <= 0:
+ raise ValueError("Model patch size must be greater than 0.")
+
+ def compute_region_features(
+ self,
+ coords,
+ masks=None,
+ timepoints=None,
+ labels=None,
+ # embeddings=None
+ ) -> torch.Tensor:
+ feats = torch.zeros(len(coords), self.hidden_state_size, device=self.device)
+ match self.mode:
+ case "nearest_patch":
+ feats = self._nearest_patches(coords, masks, norm=self.normalize_embeddings)
+ return feats # Return early, nothing else to do
+ case mode if mode.endswith("_patches_exact"):
+ if masks is None or labels is None or timepoints is None:
+ raise ValueError("Masks and labels must be provided for the chosen patch mode.")
+ feats_func = partial(self._agg_patches_exact, masks, timepoints, labels, norm=self.normalize_embeddings)
+ case mode if mode.endswith("_patches_bbox"):
+ if masks is None or labels is None or timepoints is None:
+ raise ValueError("Masks and labels must be provided for the chosen patch mode.")
+ feats_func = partial(self._agg_patches_bbox, masks, timepoints, labels, norm=self.normalize_embeddings)
+ case _:
+ raise NotImplementedError(f"Mode {self.mode} is not implemented.")
+
+ # Only for aggregation modes
+ if "max" in self.mode:
+ feats = feats_func(agg=torch.max)
+ elif "mean" in self.mode:
+ feats = feats_func(agg=torch.mean)
+ elif "median" in self.mode:
+ feats = feats_func(agg=torch.median)
+ else:
+ raise NotImplementedError(f"Unknown aggregation for mode {self.mode}")
+
+ assert feats.shape == (len(coords), self.hidden_state_size)
+ return feats # (n_regions, embedding_size)
+
+ def precompute_image_embeddings(self, images, **kwargs): # , windows, window_size):
+ """Precomputes embeddings for all images."""
+ missing = self._check_missing_embeddings()
+ all_embeddings = torch.zeros(len(images), self.final_grid_size[0] * self.final_grid_size[1], self.hidden_state_size, device=self.device)
+ if missing:
+ for ts, batches in tqdm(self._prepare_batches(images), total=len(images) // self.batch_size, desc="Computing embeddings", leave=False):
+ embeddings = self._run_model(batches, **kwargs)
+ if torch.any(embeddings.isnan()):
+ raise RuntimeError("NaN values found in features.")
+ # logger.debug(f"Embeddings shape: {embeddings.shape}")
+ all_embeddings[ts] = embeddings.to(torch.float32)
+ assert embeddings.shape[-1] == self.hidden_state_size
+ self.embeddings = all_embeddings
+ self._save_features(all_embeddings)
+ # logger.debug(f"Precomputed embeddings shape: {self.embeddings.shape}")
+ return self.embeddings
+
+ def _extract_region_embeddings(self, all_frames_embeddings, window, start_index, remaining=None):
+ window_coords = window["coords"]
+ window_timepoints = window["timepoints"]
+ window_masks = window["mask"]
+ window_labels = window["labels"]
+
+ n_regions_per_frame, features = self.extract_embedding(window_masks, window_timepoints, window_labels, window_coords)
+
+ for i in range(remaining or len(n_regions_per_frame)):
+ # if computing remaining frames' embeddings, start from the end
+ obj_per_frame = n_regions_per_frame[-i - 1] if remaining else n_regions_per_frame[i]
+ frame_index = start_index + i if not remaining else np.max(window_timepoints) - i
+ # logger.debug(f"Frame {frame_index} has {obj_per_frame} objects.")
+ all_frames_embeddings[frame_index] = features[:obj_per_frame]
+ features = features[obj_per_frame:]
+
+ def extract_embedding(self, masks, timepoints, labels, coords):
+ # if masks.shape[-2:] != self.orig_image_size:
+ # This should not be occurring since each folder is loaded as a separate CTCData
+ # However when computing augmented embeddings in parallel, the input size may change
+ # logger.debug(f"Input shape change detected: {masks.shape[-2:]} from {self.orig_image_size}.")
+ # self.orig_image_size = masks.shape[-2:]
+ n_regions_per_frame = np.unique(timepoints, return_counts=True)[1]
+ tot_regions = n_regions_per_frame.sum()
+ coords_txy = np.concatenate((timepoints[:, None], coords), axis=-1)
+ if coords_txy.shape[0] != tot_regions:
+ raise RuntimeError(f"Number of coords ({coords_txy.shape[0]}) does not match the number of coordinates ({timepoints.shape[0]}).")
+ features = self.compute_region_features(
+ masks=masks,
+ coords=coords_txy,
+ timepoints=timepoints,
+ labels=labels,
+ )
+ if torch.isnan(features).any():
+ raise RuntimeError("NaN values found in features.")
+ if tot_regions != features.shape[0]:
+ raise RuntimeError(f"Number of regions ({n_regions_per_frame}) does not match the number of embeddings ({features.shape[0]}).")
+ return n_regions_per_frame, features
+
+ @abstractmethod
+ def _run_model(self, images, **kwargs) -> torch.Tensor: # must return (B, grid_size**2, hidden_state_size)
+ """Extracts embeddings from the model."""
+ pass
+
+ def normalize_array(self, b):
+ b = percentile_norm(b)
+ if self.rescale_batches:
+ b = b * 255.0
+ return b
+
+ @staticmethod
+ def get_centroids_from_masks(masks: np.ndarray) -> np.ndarray:
+ """Computes the centroids of the objects in the masks.
+
+ Args:
+ masks: (n_objects, H, W) array of masks.
+
+ Returns:
+ Centroids: (n_objects, 2) array of (y, x) centroid coordinates, normalized to [0, 1].
+ """
+ regions = regionprops(masks)
+ centroids = np.array([region.centroid for region in regions])
+ centroids[:, 0] = centroids[:, 0] / masks.shape[1]
+ centroids[:, 1] = centroids[:, 1] / masks.shape[2]
+ return centroids
+
+ def apply_rot_to_features(self, features: torch.Tensor, centroids: np.ndarray) -> torch.Tensor:
+ """Applies a rotation to each feature vector based on the object's centroid.
+
+ Args:
+ features: (n_objects, hidden_state_size) tensor of features.
+ centroids: (n_objects, 2) array of (y, x) centroid coordinates, normalized to [0, 1].
+
+ Returns:
+ Rotated features: (n_objects, hidden_state_size)
+ """
+ n_objects, d = features.shape
+ assert d % 2 == 0, "Feature dimension must be even for rotation."
+ angle_x = torch.from_numpy(2 * np.pi * centroids[:, 0]).to(features.device)
+ angle_y = torch.from_numpy(2 * np.pi * centroids[:, 1]).to(features.device)
+
+ angles = torch.stack([angle_x, angle_y], dim=1).repeat(1, d // 2)
+ angles = angles.view(n_objects, d)
+ cos = torch.cos(angles)
+ sin = torch.sin(angles)
+ features_ = features.view(n_objects, -1, 2)
+ x_feat, y_feat = features_[..., 0], features_[..., 1]
+ x_rot = x_feat * cos[:, ::2] - y_feat * sin[:, ::2]
+ y_rot = x_feat * sin[:, ::2] + y_feat * cos[:, ::2]
+ rotated = torch.stack([x_rot, y_rot], dim=-1).reshape(n_objects, d)
+ if torch.allclose(rotated, features):
+ logger.warning("Rotated features are equal to original features. Rotation may not be applied correctly.")
+ return rotated
+
+ def _prepare_batches(self, images):
+ """Prepares batches of images for embedding extraction."""
+ if self.do_normalize:
+ images = self.normalize_array(images) # already rescales if rescale_batches is set
+ elif self.rescale_batches:
+ images = images * 255.0
+ for i in range(0, len(images), self.batch_size):
+ end = i + self.batch_size
+ end = min(end, len(images))
+ batch = np.expand_dims(images[i:end], axis=1) # (B, C, H, W)
+
+ timepoints = range(i, end)
+ if self.n_channels > 1: # repeat channels if needed
+ if self.orig_n_channels > 1 and self.orig_n_channels != self.n_channels:
+ raise ValueError("When more than one original channel is provided, the number of channels in the model must match the number of channels in the input.")
+ batch = np.repeat(batch, self.n_channels, axis=1)
+ if not self.channel_first:
+ batch = np.moveaxis(batch, 1, -1)
+ if self.batch_return_type == "list[np.ndarray]":
+ batch = list(batch)
+ yield timepoints, batch
+
+ @staticmethod
+ def normalize_tensor(embeddings: torch.Tensor, norm: bool = True) -> torch.Tensor:
+ """Normalizes the embeddings by dividing by the norm."""
+ if norm:
+ embeddings = embeddings / (embeddings.norm(dim=-1, keepdim=True) + 1e-8)
+ return embeddings
+
+ def _map_coords_to_model_grid(self, coords):
+ scale_x = self.input_size[0] / self.orig_image_size[0]
+ scale_y = self.input_size[1] / self.orig_image_size[1]
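+ # e.g. a 512x512 source image with input_size=(1024, 1024) gives scales of 2.0,
+ # so original pixel coordinates are doubled before the grid lookup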
+ coords = np.array(coords)
+ patch_x = (coords[:, 1] * scale_x).astype(int)
+ patch_y = (coords[:, 2] * scale_y).astype(int)
+ patch_coords = np.column_stack((coords[:, 0], patch_x, patch_y))
+ return patch_coords
+
+ def _find_nearest_cell(self, patch_coords):
+ """Finds the nearest cell in the grid for each patch coordinate."""
+ x_idxs = patch_coords[:, 1] // self.model_patch_size[0]
+ y_idxs = patch_coords[:, 2] // self.model_patch_size[1]
+ patch_idxs = np.column_stack((patch_coords[:, 0], x_idxs, y_idxs)).astype(int)
+ return patch_idxs
+
+ def _find_bbox_cells(self, regions: dict, cell_height: int, cell_width: int):
+ """Finds the cells in a grid that a bounding box belongs to.
+
+ Args:
+ - regions (dict): Dictionary from regionprops. Must contain bbox.
+ - cell_height (int): Height of a cell in the grid.
+ - cell_width (int): Width of a cell in the grid.
+
+ Returns:
+ - mask_patches (dict): Dictionary mapping each region label to an array of grid cell indices that its bounding box intersects.
+ """
+ mask_patches = {}
+ for region in regions:
+ minr, minc, maxr, maxc = region.bbox
+ patches = self._find_region_cells(minr, minc, maxr, maxc, cell_height, cell_width)
+ mask_patches[region.label] = patches
+
+ return mask_patches
+
+ @staticmethod
+ def _find_region_cells(minr, minc, maxr, maxc, cell_height, cell_width):
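+ # Worked example (illustrative): bbox (minr, minc, maxr, maxc) = (10, 20, 40, 50)
+ # with 16x16 cells spans grid rows 0..2 and columns 1..3, i.e. 9 cells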
+ start_patch_y = minr // cell_height
+ end_patch_y = (maxr - 1) // cell_height
+ start_patch_x = minc // cell_width
+ end_patch_x = (maxc - 1) // cell_width
+ patches = np.array([(i, j) for i in range(start_patch_y, end_patch_y + 1) for j in range(start_patch_x, end_patch_x + 1)])
+ return patches
+
+ def _find_patches_for_masks(self, image_mask: np.ndarray) -> dict:
+ """Find which patches in a grid each mask belongs to using regionprops.
+
+ Args:
+ - image_mask (np.ndarray): Masks where each region has a unique label.
+
+ Returns:
+ - mask_patches (dict): Dictionary with region labels as keys and lists of patch indices as values.
+ """
+ patch_height = image_mask.shape[0] // self.final_grid_size[0]
+ patch_width = image_mask.shape[1] // self.final_grid_size[1]
+ regions = regionprops(image_mask)
+ return self._find_bbox_cells(regions, patch_height, patch_width)
+
+ def _debug_show_patches(self, embeddings, masks, coords, patch_idxs):
+ import napari
+ v = napari.Viewer()
+ # v.add_labels(masks)
+ e = embeddings.detach().cpu().numpy().swapaxes(1, 2)
+ e = e.reshape(-1, self.hidden_state_size, self.final_grid_size[0], self.final_grid_size[1]).swapaxes(0, 1)
+
+ v.add_image(
+ e,
+ name="Embeddings",
+ )
+ # add red points at patch indices for the relevant frame
+ points = np.zeros((len(patch_idxs) * self.hidden_state_size, 3))
+ for i, (t, y, x) in enumerate(patch_idxs):
+ point = np.array([t, y, x])
+ points[i * self.hidden_state_size:(i + 1) * self.hidden_state_size] = np.tile(point, (self.hidden_state_size, 1))
+
+ v.add_points(points, size=1, face_color='red', name='Patch Indices')
+
+ from skimage.transform import resize
+ masks_resized = resize(masks[0], (self.final_grid_size[0], self.final_grid_size[1]), anti_aliasing=False, order=0, preserve_range=True)
+ v.add_labels(masks_resized)
+ logger.debug(f"Lost labels : {set(np.unique(masks)) - set(np.unique(masks_resized))}")
+
+ napari.run()
+
+ def _nearest_patches(self, coords, masks=None, norm=True, embs=None):
+ """Finds the nearest patches to the detections in the embedding."""
+ # find coordinate patches from detections
+ patch_coords = self._map_coords_to_model_grid(coords)
+ patch_idxs = self._find_nearest_cell(patch_coords)
+ # logger.debug(f"Patch indices: {patch_idxs}")
+
+ # load the embeddings and extract the relevant ones
+ feats = torch.zeros(len(coords), self.hidden_state_size, device=self.device)
+ indices = [y * self.final_grid_size[1] + x for _, y, x in patch_idxs]
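+ # row-major flattening: grid cell (y, x) maps to flat index y * grid_width + x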
+ unique_timepoints = list(set(t for t, _, _ in patch_idxs))
+ # logger.debug(f"Unique timepoints: {unique_timepoints}")
+ embeddings = self._load_features() if embs is None else embs
+
+ # try:
+ # t = coords[0][0]
+ # if t == 3:
+ # self._debug_show_patches(embeddings, masks, coords, patch_idxs)
+ # except IndexError:
+ # logger.debug("No timepoint found in coords.")
+
+ # logger.debug(f"Embeddings shape: {embeddings.shape}")
+ embeddings_dict = {t: embeddings[t] for t in unique_timepoints}
+ try:
+ for i, (t, _, _) in enumerate(patch_idxs):
+ feats[i] = embeddings_dict[t][indices[i]]
+ except KeyError as e:
+ logger.error(f"KeyError: {e} - Check if the timepoint exists in embeddings_dict.")
+ except IndexError as e:
+ # TODO improve handling of this error. Maybe check shape earlier
+ logger.error(f"IndexError: {e} - Embeddings exist but do not have the correct shape. Did the model input size change ? If so, please delete saved embeddings and recompute.")
+ feats = FeatureExtractor.normalize_tensor(feats, norm=norm)
+ if self.apply_rope:
+ centroids = FeatureExtractor.get_centroids_from_masks(masks)
+ feats = self.apply_rot_to_features(feats, centroids)
+ return feats
+
+ # @average_time_decorator
+ def _agg_patches_bbox(self, masks, timepoints, labels, agg=torch.mean, norm=True, embs=None):
+ """Averages the embeddings of all patches that intersect with the detection.
+
+ Args:
+ - masks (np.ndarray): Masks where each region has a unique label (t x H x W).
+ - timepoints (np.ndarray): For each region, contains the corresponding timepoint. (n_regions)
+ - labels (np.ndarray): Unique labels of the regions. (n_regions)
+ - agg (callable): Aggregation function to use for averaging the embeddings.
+ """
+ try:
+ n_regions = len(timepoints)
+ timepoints_shifted = timepoints - timepoints.min()
+ except ValueError:
+ logger.error("Error: issue computing shifted timepoints.")
+ logger.error(f"Regions: {len(timepoints)}")
+ logger.error(f"Timepoints: {timepoints}")
+ return torch.zeros(n_regions, self.hidden_state_size, device=self.device)
+
+ feats = torch.zeros(n_regions, self.hidden_state_size, device=self.device)
+ patches = []
+ times = np.unique(timepoints_shifted)
+ patches_res = joblib.Parallel(n_jobs=8, backend="threading")(
+ joblib.delayed(self._find_patches_for_masks)(masks[t]) for t in times
+ )
+ patches = {t: patch for t, patch in zip(times, patches_res)}
+ # logger.debug(f"Patches : {patches}")
+
+ embeddings = self._load_features() if embs is None else embs
+
+ def process_region(i, t):
+ patches_feats = []
+ for patch in patches[t][labels[i]]:
+ # row-major: (row, col) -> row * grid_width + col, consistent with _nearest_patches
+ embs = embeddings[t][patch[0] * self.final_grid_size[1] + patch[1]]
+ embs = FeatureExtractor.normalize_tensor(embs, norm=norm)
+ patches_feats.append(embs)
+ aggregated = agg(torch.stack(patches_feats), dim=0)
+ # If agg is torch.max, extract only the values
+ if isinstance(aggregated, torch.return_types.max):
+ aggregated = aggregated.values
+ return aggregated
+
+ res = joblib.Parallel(n_jobs=8, backend="threading")(
+ joblib.delayed(process_region)(i, t) for i, t in enumerate(timepoints_shifted)
+ )
+
+ for i, r in enumerate(res):
+ feats[i] = r
+
+ return feats
+
+ def _agg_patches_debug_view(self, v, region_mask, lab=None):
+ """Debug function to visualize the patches and their embeddings."""
+ # Add region mask
+ v.add_labels(
+ region_mask,
+ name=f"Region Mask {lab}",
+ opacity=0.5,
+ blending="translucent",
+ )
+
+ def _view_embeddings(self, embeddings, context: dict | None = None, **kwargs):
+ # If this causes issues with augmented feature computation because
+ # the extractor grid size depends on image size,
+ # redefine it as appropriate in the subclass.
+ # Use context to pass information as needed for the use case.
+
+ # Currently redefined in :
+ # - CoTrackerFeatures
+ embs = embeddings.view(
+ -1, self.final_grid_size[0], self.final_grid_size[1], self.hidden_state_size
+ )
+ return embs, self.final_grid_size
+
+ def _agg_patches_exact(self, masks, timepoints, labels, agg=torch.mean, norm=True, embs=None):
+ """Aggregates the embeddings of all patches that strictly belong to the mask."""
+ try:
+ n_regions = len(timepoints)
+ timepoints_shifted = timepoints - timepoints.min()
+ except ValueError:
+ logger.error("Error: issue computing shifted timepoints.")
+ logger.error(f"Regions: {len(timepoints)}")
+ logger.error(f"Timepoints: {timepoints}")
+ return torch.zeros(n_regions, self.hidden_state_size, device=self.device)
+
+ feats = torch.zeros(n_regions, self.hidden_state_size, device=self.device)
+ embeddings = self._load_features() if embs is None else embs
+ embeddings, grid_size = self._view_embeddings(embeddings, context={"masks_shape": masks.shape})
+
+ _T, H, W = masks.shape
+ # assert embeddings.shape[0] == _T, f"Embeddings times {embeddings.shape} does not match masks times {_T}."
+ grid_H, grid_W = grid_size
+ scale_y = grid_H / H
+ scale_x = grid_W / W
+
+ if self._debug_view:
+ import napari
+ if napari.current_viewer() is None:
+ v = napari.Viewer()
+ else:
+ v = napari.current_viewer()
+ if "Masks" not in v.layers:
+ v.add_labels(masks[0], name="Masks")
+ if "Embeddings" not in v.layers:
+ embs = embeddings.view(
+ -1, self.final_grid_size[0], self.final_grid_size[1], self.hidden_state_size
+ )
+ v.add_image(
+ embs.permute(3, 0, 1, 2).cpu().numpy(),
+ name="Embeddings",
+ colormap="inferno",
+ )
+
+ def process_region(i, t, masks):
+ if masks.shape[0] == 1:
+ masks = masks.squeeze(0)
+ mask_reg = masks == labels[i]
+ else:
+ mask_reg = masks[t] == labels[i]
+ if not np.any(mask_reg):
+ logger.warning(f"No pixels found for region {labels[i]} at timepoint {t}.")
+ # return small values instead of zeros to avoid zero divisions downstream
+ return torch.full((self.hidden_state_size,), 1e-8, device=self.device)
+
+ y_idxs, x_idxs = np.nonzero(mask_reg)
+ grid_y = np.clip((y_idxs * scale_y).astype(int), 0, grid_H - 1)
+ grid_x = np.clip((x_idxs * scale_x).astype(int), 0, grid_W - 1)
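+ # every foreground pixel maps to a grid cell; cells covered by more pixels are
+ # therefore weighted more heavily in the aggregation below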
+ patch_embeddings = embeddings[timepoints[i]][grid_y, grid_x]
+ # normalizing before the mean seems most effective
+ patch_embeddings = FeatureExtractor.normalize_tensor(patch_embeddings, norm=norm)
+
+ if self._debug_view:
+ mask_emb = np.zeros((grid_H, grid_W), dtype=np.uint16)
+ mask_emb[grid_y, grid_x] = labels[i]
+ self._agg_patches_debug_view(v, mask_emb, labels[i])
+
+ if patch_embeddings.shape[0] == 0:
+ logger.warning(f"No mapped pixels for region {labels[i]} at timepoint {t}.")
+ return torch.zeros(self.hidden_state_size, device=self.device)
+ return agg(patch_embeddings, dim=0)
+
+ # Parallel processing
+ if self.parallel:
+ res = joblib.Parallel(n_jobs=8, backend="threading")(
+ joblib.delayed(process_region)(i, t, masks=masks) for i, t in enumerate(timepoints_shifted)
+ )
+ for i, r in enumerate(res):
+ # If agg is torch.max or torch.median, extract only the values
+ if isinstance(r, (torch.return_types.max, torch.return_types.median)):
+ feats[i] = r.values
+ else:
+ feats[i] = r
+ else:
+ for i, t in enumerate(timepoints_shifted):
+ feats[i] = process_region(i, t, masks)
+ if self._debug_view:
+ napari.run()
+ # if norm:
+ # feats = feats / feats.norm(dim=-1, keepdim=True)
+ return feats
+
+ def _exact_patch(self, masks, timepoints, labels, norm=False):
+ """Returns all embeddings overlapping with the mask of each object."""
+ try:
+ n_regions = len(timepoints)
+ timepoints_shifted = timepoints - timepoints.min()
+ except ValueError:
+ logger.error("Error: issue computing shifted timepoints.")
+ logger.error(f"Regions: {len(timepoints)}")
+ logger.error(f"Timepoints: {timepoints}")
+ return torch.zeros(n_regions, self.hidden_state_size, device=self.device)
+
+ feats = torch.zeros(n_regions, self.hidden_state_size, device=self.device)
+ embeddings = self._load_features()
+ embeddings = embeddings.view(
+ -1, self.final_grid_size[0], self.final_grid_size[1], self.hidden_state_size
+ )
+
+ _T, H, W = masks.shape
+ grid_H, grid_W = self.final_grid_size
+ scale_y = grid_H / H
+ scale_x = grid_W / W
+
+ def process_region(i, t, masks):
+ if masks.shape[0] == 1: # single timepoint
+ masks = masks.squeeze(0)
+ mask_reg = masks == labels[i]
+ else: # all timepoints
+ mask_reg = masks[t] == labels[i]
+ if not np.any(mask_reg):
+ logger.warning(f"No pixels found for region {labels[i]} at timepoint {t}.")
+ return torch.zeros(self.hidden_state_size, device=self.device)
+
+ y_idxs, x_idxs = np.nonzero(mask_reg)
+ grid_y = np.clip((y_idxs * scale_y).astype(int), 0, grid_H - 1)
+ grid_x = np.clip((x_idxs * scale_x).astype(int), 0, grid_W - 1)
+ patch_embeddings = embeddings[timepoints[i]][grid_y, grid_x]
+ if patch_embeddings.shape[0] == 0:
+ logger.warning(f"No mapped pixels for region {labels[i]} at timepoint {t}.")
+ return torch.zeros(self.hidden_state_size, device=self.device)
+ return patch_embeddings
+
+ # Parallel processing
+ res = joblib.Parallel(n_jobs=8, backend="threading")(
+ joblib.delayed(process_region)(i, t, masks=masks) for i, t in enumerate(timepoints_shifted)
+ )
+ for i, r in enumerate(res):
+ feats[i] = r
+ # for i, t in enumerate(timepoints_shifted):
+ # feats[i] = process_region(i, t, masks)
+
+ if norm:
+ feats = feats / feats.norm(dim=-1, keepdim=True)
+ return feats
+
+ def _save_features(self, features): # , timepoint):
+ """Saves the features to disk."""
+ # save_path = self.save_path / f"{timepoint}_{self.model_name_path}_features.npy"
+ self.embeddings = features
+ if not self.do_save:
+ return
+ save_path = self.save_path / f"{self.model_name_path}_features.npy"
+ np.save(save_path, features.cpu().numpy())
+ assert save_path.exists(), f"Failed to save features to {save_path}"
+
+ def _load_features(self): # , timepoint):
+ """Loads the features from disk."""
+ # load_path = self.save_path / f"{timepoint}_{self.model_name_path}_features.npy"
+ if self.embeddings is None:
+ load_path = self.save_path / f"{self.model_name_path}_features.npy"
+ if load_path.exists():
+ features = np.load(load_path)
+ assert features is not None, f"Failed to load features from {load_path}"
+ if np.any(np.isnan(features)):
+ raise RuntimeError(f"NaN values found in features loaded from {load_path}.")
+ # check feature shape consistency
+ if features.shape[1] != self.final_grid_size[0] * self.final_grid_size[1] or features.shape[2] != self.hidden_state_size:
+ logger.error(f"Saved embeddings found, but shape {features.shape} does not match expected shape {('n_frames', self.final_grid_size[0] * self.final_grid_size[1], self.hidden_state_size)}.")
+ logger.error("Embeddings will be recomputed.")
+ return None
+ logger.info("Saved embeddings loaded.")
+ self.embeddings = torch.tensor(features).to(self.device)
+ return self.embeddings
+ else:
+ logger.info(f"No saved embeddings found at {load_path}. Features will be computed.")
+ return None
+ else:
+ return self.embeddings
+
+ def _check_missing_embeddings(self):
+ """Checks if embeddings for the model already exist or are missing.
+
+ Returns whether the embeddings need to be recomputed.
+ """
+ if self.force_recompute:
+ return True
+ try:
+ features = self._load_features()
+ except FileNotFoundError:
+ return True
+ if features is None:
+ return True
+ else:
+ logger.info(f"Embeddings for {self.model_name} already exist. Skipping embedding computation.")
+ return False
+
+
+class FeatureExtractorAugWrapper:
+ """Wrapper for the FeatureExtractor class to apply augmentations."""
+ def __init__(
+ self,
+ extractor: FeatureExtractor,
+ augmenter: "PretrainedAugmentations",
+ n_aug: int = 1,
+ force_recompute: bool = False,
+ ):
+ self.extractor = extractor
+ self.additional_features = extractor.additional_features
+ self.n_aug = n_aug
+ self.aug_pipeline = augmenter
+ self.all_aug_features = {} # n_aug -> {aug_id: {metadata, data}}
+ # data -> {t: {lab: {"coords": coords, "features": features}}}
+ self.image_shape_reference = {}
+
+ self.extractor.force_recompute = True
+ self.extractor.do_save = False # do not save intermediate features (augmented image embeddings)
+ # instead, we will save the augmented features + coordinates on a per-object basis in a zarr store
+ self.extractor.do_normalize = False
+ self.extractor.parallel = False
+ # already parallelized, faster this way since it avoids the overhead of spawning
+ # many small processes within the parallelized augmentation pipeline
+
+ self._zarr_sync = zarr.ProcessSynchronizer(str(self.get_save_path()) + ".sync")
+ self.force_recompute = force_recompute
+
+ self._debug_view = None
+
+ def get_save_path(self):
+ root_path = self.extractor.save_path / "aug"
+ if not root_path.exists():
+ root_path.mkdir(parents=True, exist_ok=True)
+ return root_path / f"{self.extractor.model_name_path}_aug.zarr"
+
+ def _check_existing(self):
+ save_path = self.get_save_path()
+ if not save_path.exists() or self.force_recompute:
+ logger.debug(f"Augmentation zarr store {save_path} does not exist or force_recompute is True. Recomputing features.")
+ return False, None, None
+ logger.info(f"Loading existing features from {save_path}...")
+ # root = zarr.open_group(str(save_path), mode="r")
+ # existing_augs = [k for k in root.keys() if k.isdigit()]
+ features_dict = self.load_all_features()
+ existing_augs = list(features_dict.keys())
+ logger.info("Done.")
+ return True, existing_augs, features_dict
+
+ def _compute(self, images, masks):
+ """Computes the features for the images and masks."""
+ images_shape = images.shape
+ if len(images_shape) != 3:
+ images_shape = images_shape[1:] # remove batch dimension if present
+ embs = self.extractor.precompute_image_embeddings(images, image_shape=images_shape)
+
+ if self._debug_view is not None:
+ embs = embs.cpu().numpy()
+ logger.debug(f"Embeddings shape: {embs.shape}")
+ embs = embs.reshape(-1, self.extractor.final_grid_size[0], self.extractor.final_grid_size[1], self.extractor.hidden_state_size)
+ embs = np.moveaxis(embs, 3, 0)
+ self._debug_view.add_image(embs, name="Embeddings", colormap="inferno")
+ self._debug_view.add_image(images.cpu().numpy(), name="Images", colormap="viridis")
+ self._debug_view.add_labels(masks.cpu().numpy(), name="Masks")
+
+ images, masks = images.cpu().numpy(), masks.cpu().numpy()
+ # features = wrfeat.WRAugPretrainedFeatures.from_mask_img(
+ # img=images,
+ # mask=masks,
+ # feature_extractor=self.extractor,
+ # t_start=0,
+ # additional_properties=self.extractor.additional_features,
+ # )
+ features = [
+ wrfeat.WRAugPretrainedFeatures.from_mask_img(
+ # embeddings=embs,
+ img=img[np.newaxis],
+ mask=mask[np.newaxis],
+ feature_extractor=self.extractor,
+ t_start=t,
+ additional_properties=self.extractor.additional_features
+ )
+ for t, (mask, img) in tqdm(
+ enumerate(zip(masks, images)), desc="Computing features...", total=len(masks), leave=False
+ ) # if t == 10 debug
+ ]
+ features_dict = {t: v for f in features for t, v in f.to_dict().items()}
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ return features_dict
+
+ def _compute_original(self, images, masks):
+ """Computes the original features for the images and masks."""
+ images, masks = self.aug_pipeline.preprocess(images, masks, normalize_func=self.extractor.normalize_array)
+ orig_feat_dict = self._compute(images, masks)
+ self.image_shape_reference[0] = images.shape[-2:]
+ return orig_feat_dict
+
+ def _compute_augmented(self, images, masks, n):
+ images, masks = self.aug_pipeline.preprocess(images, masks, normalize_func=self.extractor.normalize_array)
+ aug_images, aug_masks, aug_record = self.aug_pipeline(images, masks)
+
+ # check for NaNs
+ if torch.isnan(aug_images).any() or torch.isnan(aug_masks).any():
+ raise RuntimeError("NaN values found in augmented images or masks.")
+
+ im_shape, masks_shape = aug_images.shape, aug_masks.shape
+ assert im_shape == masks_shape, f"Augmented images shape {im_shape} does not match augmented masks shape {masks_shape}."
+ if im_shape[-2:] != self.extractor.orig_image_size:
+ # if isinstance(self.extractor, TAPFeatures):
+ # self.extractor.final_grid_size = (im_shape[-2], im_shape[-1]) # TAP features have same dims as images
+ if isinstance(self.extractor, CoTrackerFeatures):
+ stride = self.extractor.model.stride
+ self.extractor.final_grid_size = (im_shape[-2] // stride, im_shape[-1] // stride)
+ if im_shape[-1] == 0 or im_shape[-2] == 0:
+ raise ValueError(f"Augmented images have invalid shape {im_shape}. Cannot extract features.")
+ self.extractor.orig_image_size = im_shape[-2:]
+
+ aug_feat_dict = self._compute(aug_images, aug_masks)
+ self.image_shape_reference[n] = aug_images.shape[-2:]
+ return aug_feat_dict, aug_record
+
+ def _process_aug(self, n, images, masks, existing_aug_ids, existing_features_dict):
+ try:
+ if str(f"{n + 1}") in existing_aug_ids:
+ logger.info(f"Augmentation {n + 1} already exists. Skipping computation.")
+ aug_feat_dict = existing_features_dict[str(n + 1)]["data"]
+ aug_record = existing_features_dict[str(n + 1)]["metadata"]
+ else:
+ aug_feat_dict, aug_record = self._compute_augmented(images, masks, n=n + 1)
+ result = {
+ "n": n,
+ "metadata": aug_record,
+ "data": aug_feat_dict,
+ # "should_save": str(n + 1) not in existing_aug_ids,
+ }
+ if str(n + 1) not in existing_aug_ids:
+ self._save_features(n + 1, result)
+ return result
+ except Exception as e:
+ logger.error(f"Error processing augmentation {n + 1}: {e}")
+ raise e
+
+ def compute_all_features(self, images, masks, clear_mem=True, n_workers=8) -> dict:
+ """Augments the images and masks, computes the embeddings, and saves features incrementally."""
+ # check existing features
+ present, existing_augs, existing_features_dict = self._check_existing()
+ save_path = self.get_save_path()
+ if present:
+ logger.debug(f"Saved features found at {save_path}.")
+ if len(existing_augs) == self.n_aug + 1:
+ logger.info(f"All {self.n_aug} augmentations + original already exist. Loading existing features.")
+ self.all_aug_features = existing_features_dict
+ return self.all_aug_features
+ else:
+ logger.info("No existing augmentations found.")
+ logger.debug(f"Existing augmentations: {existing_augs}")
+ existing_aug_ids = existing_augs if existing_augs is not None else []
+ logger.debug(f"Existing augmentations IDs: {existing_aug_ids}")
+
+ if "0" not in existing_aug_ids:
+ orig_feat_dict = self._compute_original(images, masks)
+ else:
+ logger.info("Original features already exist. Skipping computation for original features.")
+ orig_feat_dict = existing_features_dict["0"]["data"]
+
+ self.all_aug_features = {
+ "0": {
+ "data": orig_feat_dict,
+ "metadata": {
+ "image_shape": images.shape,
+ }
+ }
+ }
+ if "0" not in existing_aug_ids:
+ self._save_features(0, self.all_aug_features["0"])
+
+ disable_parallel = False
+ if isinstance(self.extractor, (CoTrackerFeatures, TAPFeatures, MicroSAMFeatures)):
+ # CoTrackerFeatures and TAPFeatures use a different grid size for each image,
+ # which requires a different approach to parallel processing.
+ # As a quick fix, parallel processing is disabled
+ # TODO make necessary changes to CoTrackerFeatures to allow parallel processing
+ disable_parallel = True
+ logger.debug(f"Disabling parallel processing for {self.extractor.__class__.__name__} due to variable grid size.")
+
+ if n_workers == 0 or disable_parallel:
+ for n in range(self.n_aug):
+ res = self._process_aug(n, images, masks, existing_aug_ids, existing_features_dict)
+ self.all_aug_features[str(n + 1)] = res
+ # if res["should_save"]:
+ # self._save_features(n + 1, self.all_aug_features[str(n + 1)])
+ else:
+ # joblib parallel processing
+ results = joblib.Parallel(n_jobs=n_workers, backend="threading")(
+ joblib.delayed(self._process_aug)(
+ n, images, masks, existing_aug_ids, existing_features_dict
+ ) for n in range(self.n_aug)
+ )
+ for res in results:
+ n = res["n"]
+ self.all_aug_features[str(n + 1)] = {
+ "metadata": res["metadata"],
+ "data": res["data"],
+ }
+ # if res["should_save"]:
+ # self._save_features(n + 1, self.all_aug_features[str(n + 1)])
+
+ if clear_mem:
+ self.extractor.embeddings = self.extractor.embeddings.cpu()
+ try:
+ self.extractor.model = self.extractor.model.cpu()
+ except AttributeError as e:
+ logger.error(f"Model attribute not found: {e}. Skipping model transfer to CPU.")
+ self.extractor = None
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return self.all_aug_features
+
+ # def _create_feat_dict(self, labels, ts, coords, features):
+ # """Creates a dictionary with the augmented features."""
+ # aug_feat_dict = {}
+ # features = features.cpu().numpy()
+ # for i, (t, lab) in enumerate(zip(ts, labels)):
+ # t = int(t)
+ # lab = int(lab)
+ # if t not in aug_feat_dict:
+ # aug_feat_dict[t] = {}
+ # aug_feat_dict[t][lab] = {
+ # "coords": coords[i],
+ # "features": features[i],
+ # }
+ # return aug_feat_dict
+
+ def _save_features(self, aug_id: int, aug_data: dict):
+ """Saves the features for a specific augmentation to disk as zarr (fast, flat layout)."""
+
+ save_path = self.get_save_path()
+ compressor = Blosc(cname='zstd', clevel=3, shuffle=Blosc.BITSHUFFLE)
+ root = zarr.open_group(str(save_path), mode="a", synchronizer=self._zarr_sync)
+ group_name = str(aug_id)
+ if group_name in root:
+ del root[group_name]
+ group = root.create_group(group_name)
+ group.attrs["metadata"] = json.dumps(aug_data.get("metadata", {}))
+
+ coords_list = []
+ t_list = []
+ lab_list = []
+ features_dict = {}
+
+ for t, data in aug_data["data"].items():
+ for lab, lab_data in data.items():
+ coords_list.append(lab_data["coords"])
+ t_list.append(t)
+ lab_list.append(lab)
+ for key, value in lab_data["features"].items():
+ if key not in features_dict:
+ features_dict[key] = []
+ features_dict[key].append(value)
+
+ coords_arr = np.stack(coords_list)
+ t_arr = np.array(t_list)
+ lab_arr = np.array(lab_list)
+ group.create_dataset("coords", data=coords_arr, compressor=compressor)
+ group.create_dataset("timepoints", data=t_arr, compressor=compressor)
+ group.create_dataset("labels", data=lab_arr, compressor=compressor)
+
+ features_group = group.create_group("features")
+ for key, values in features_dict.items():
+ features_group.create_dataset(key, data=np.stack(values), compressor=compressor)
+
+ # logger.debug(f"Augmented features for augmentation {aug_id} saved to {save_path}.")
+
+ def load_all_features(self) -> dict:
+ """Loads all features from disk."""
+ save_path = self.get_save_path()
+ if not save_path.exists():
+ raise FileNotFoundError(f"Path {save_path} does not exist.")
+
+ features = FeatureExtractorAugWrapper.load_features(
+ save_path,
+ additional_props=self.additional_features,
+ )
+ self.all_aug_features = features
+ return features
+
+ @staticmethod
+ def load_features(path: str | Path, additional_props: str | None = None, n_jobs: int = 12) -> dict:
+ """Loads the features for all augmentations from disk (flat zarr layout, parallelized)."""
+
+ if not isinstance(path, Path):
+ path = Path(path)
+ if not path.exists():
+ raise FileNotFoundError(f"Path {path} does not exist.")
+
+ if additional_props is not None:
+ required_features = wrfeat._PROPERTIES[additional_props]
+ if len(required_features) == 0:
+ required_features = "pretrained_feats"
+ else:
+ required_features += ("pretrained_feats", )
+
+ root = zarr.open_group(str(path), mode="r")
+ aug_ids = [k for k in root.keys() if k.isdigit()]
+
+ def _load_single(aug_id):
+ if aug_id not in root:
+ logger.error(f"Augmentation {aug_id} is missing from the zarr store. Skipping.")
+ return aug_id, None
+ group = root[aug_id]
+ try:
+ metadata = json.loads(group.attrs["metadata"])
+ except KeyError:
+ metadata = None
+ try:
+ coords_arr = group["coords"][...]
+ t_arr = group["timepoints"][...]
+ lab_arr = group["labels"][...]
+ features_group = group["features"]
+ features_dict = {}
+ for key in features_group.keys():
+ if additional_props is not None and key not in required_features:
+ continue
+ features_dict[key] = features_group[key][...]
+ if additional_props is not None:
+ missing_keys = [k for k in required_features if k not in features_dict]
+ if missing_keys:
+ raise RuntimeError(
+ f"Missing required features {missing_keys} in augmentation {aug_id}. "
+ f"Please delete the cache at {path} and recompute the features."
+ )
+
+ data = {}
+ for i in range(len(t_arr)):
+ t = int(t_arr[i])
+ lab = int(lab_arr[i])
+ if t not in data:
+ data[t] = {}
+ feats = {k: features_dict[k][i] for k in features_dict}
+ data[t][lab] = {
+ "coords": coords_arr[i],
+ "features": feats,
+ }
+ return aug_id, {"metadata": metadata, "data": data}
+ except KeyError as e:
+ logger.error(f"KeyError: {e} - Augmentation {aug_id} is missing some data. Skipping.")
+ return aug_id, None
+
+ results = joblib.Parallel(n_jobs=n_jobs, backend="threading")(
+ joblib.delayed(_load_single)(aug_id) for aug_id in aug_ids
+ )
+ all_data = {aug_id: aug_data for aug_id, aug_data in results if aug_data is not None}
+ return all_data
+
+
+##############
+@register_backbone("facebook/hiera-tiny-224-hf", 768)
+class HieraFeatures(FeatureExtractor):
+ model_name = "facebook/hiera-tiny-224-hf"
+
+ def __init__(
+ self,
+ image_size: tuple[int, int],
+ save_path: str | Path,
+ batch_size: int = 16,
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ mode: PretrainedFeatsExtractionMode = "nearest_patch",
+ **kwargs,
+ ):
+ super().__init__(image_size, save_path, batch_size, device, mode, **kwargs)
+ # self.input_size = 224
+ self.input_mul = 3
+ self.input_size = int(self.input_mul * 224)
+ self.final_grid_size = int(7 * self.input_mul) # default is 7x7 grid
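+ # scaling the 224-px default input by input_mul preserves the effective
+ # stride of 32 px (224 / 7), so the 7x7 grid grows to (7 * input_mul)^2 cells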
+ self.n_channels = 3
+ self.hidden_state_size = 768
+ self.rescale_batches = False
+
+ ##
+ self.im_proc_kwargs["size"] = self.input_size
+ ##
+ self.image_processor = AutoImageProcessor.from_pretrained(self.model_name)
+ config = HieraConfig.from_pretrained(self.model_name)
+ config.image_size = [self.input_size[0], self.input_size[1]]
+ # logger.debug(f"Config: {config}")
+ # self.model = HieraModel.from_pretrained(self.model_name)
+ # self.model.config.image_size = [self.input_size, self.input_size]
+ self.model = HieraModel(config)
+ self.model.to(self.device)
+ # self.model.embeddings.patch_embeddings.num_patches = (self.input_size // self.model.config.patch_size[0]) ** 2
+ # self.model.embeddings.position_embeddings = torch.nn.Parameter(
+ # torch.zeros(1, self.model.embeddings.patch_embeddings.num_patches + 1, self.hidden_state_size)
+ # )
+ # self.model.embeddings.position_ids = torch.arange(0, self.model.embeddings.patch_embeddings.num_patches + 1).unsqueeze(0)
+
+ def _run_model(self, images, **kwargs) -> torch.Tensor:
+ """Extracts embeddings from the model."""
+ # images = self._normalize_batch(images)
+ inputs = self.image_processor(images, **self.im_proc_kwargs).to(self.device)
+
+ with torch.no_grad():
+ outputs = self.model(**inputs)
+
+ return outputs.last_hidden_state
+
+
+@register_backbone("facebook/dinov2-base", 768)
+class DinoV2Features(FeatureExtractor):
+ model_name = "facebook/dinov2-base"
+
+ def __init__(
+ self,
+ image_size: tuple[int, int],
+ save_path: str | Path,
+ batch_size: int = 16,
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ mode: PretrainedFeatsExtractionMode = "nearest_patch",
+ **kwargs,
+ ):
+ super().__init__(image_size, save_path, batch_size, device, mode)
+ self.input_size = 224
+ self.final_grid_size = 16 # 16x16 grid
+ self.n_channels = 3 # expects RGB images
+ self.hidden_state_size = 768
+ self.image_processor = AutoImageProcessor.from_pretrained(self.model_name)
+ ##
+ self.im_proc_kwargs["size"] = self.input_size
+ ##
+ self.model = AutoModel.from_pretrained(self.model_name)
+ self.rescale_batches = False
+ # config = Dinov2Config.from_pretrained(self.model_name)
+ # config.image_size = self.input_size
+
+ # self.model = Dinov2Model(config)
+ # logger.info(f"Model from config: {self.model.config}")
+ # logger.info(f"Pretrained model : {Dinov2Model.from_pretrained(self.model_name).config}")
+ self.model.to(self.device)
+
+ def _run_model(self, images, **kwargs) -> torch.Tensor:
+ """Extracts embeddings from the model."""
+ inputs = self.image_processor(images, **self.im_proc_kwargs).to(self.device)
+
+ with torch.no_grad():
+ outputs = self.model(**inputs)
+
+ # ignore the CLS token (not classifying)
+ # this way we get only the patch embeddings
+ # which are compatible with finding the relevant patches directly
+ # in the rest of the code
+ return outputs.last_hidden_state[:, 1:, :]
+
+
+@register_backbone("facebook/sam-vit-base", 256)
+class SAMFeatures(FeatureExtractor):
+ model_name = "facebook/sam-vit-base"
+
+ def __init__(
+ self,
+ image_size: tuple[int, int],
+ save_path: str | Path,
+ batch_size: int = 4,
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ mode: PretrainedFeatsExtractionMode = "nearest_patch",
+ **kwargs,
+ ):
+ super().__init__(image_size, save_path, batch_size, device, mode)
+ self.input_size = 1024
+ self.final_grid_size = 64 # 64x64 grid
+ self.n_channels = 3
+ self.hidden_state_size = 256
+ self.image_processor = SamProcessor.from_pretrained(self.model_name)
+ self.model = SamModel.from_pretrained(self.model_name)
+ self.rescale_batches = False
+
+ self.model.to(self.device)
+
+ def _run_model(self, images, **kwargs) -> torch.Tensor:
+ """Extracts embeddings from the model."""
+ inputs = self.image_processor(images, **self.im_proc_kwargs).to(self.device)
+ outputs = self.model.get_image_embeddings(inputs['pixel_values'])
+ B, N, H, W = outputs.shape
+ return outputs.permute(0, 2, 3, 1).reshape(B, H * W, N) # (B, grid_size**2, hidden_state_size)
+
+
+@register_backbone("facebook/sam2-hiera-large", 256)
+@register_backbone("facebook/sam2.1-hiera-base-plus", 256)
+class SAM2Features(FeatureExtractor):
+ model_name = "facebook/sam2.1-hiera-base-plus"
+
+ def __init__(
+ self,
+ image_size: tuple[int, int],
+ save_path: str | Path,
+ batch_size: int = 4,
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ mode: PretrainedFeatsExtractionMode = "nearest_patch",
+ **kwargs,
+ ):
+ super().__init__(image_size, save_path, batch_size, device, mode)
+ self.input_size = 1024
+ self.final_grid_size = 64 # 64x64 grid
+ self.n_channels = 3
+ self.hidden_state_size = 256
+ self.model = SAM2ImagePredictor.from_pretrained(self.model_name, device=self.device).model
+
+ self.batch_return_type = "list[np.ndarray]"
+ self.channel_first = True
+ self.rescale_batches = False
+
+ self._bb_feat_sizes = [
+ (256, 256),
+ (128, 128),
+ (64, 64),
+ ]
+ if self.rescale_batches:
+ print("Rescaling batches to [0, 255] range.")
+
+ @torch.no_grad()
+ def _run_model(self, images: list[np.ndarray], **kwargs) -> torch.Tensor:
+ """Extracts embeddings from the model."""
+ with torch.autocast(device_type=self.device), torch.inference_mode():
+ images_ten = torch.stack([torch.tensor(image) for image in images]).to(self.device)
+ # logger.debug(f"Image dtype: {images_ten.dtype}")
+ # logger.debug(f"Image shape: {images_ten.shape}")
+ # logger.debug(f"Image min : {images_ten.min()}, max: {images_ten.max()}")
+ images_ten = F.interpolate(images_ten, size=(self.input_size[0], self.input_size[1]), mode="bilinear", align_corners=False)
+ # from torchvision.transforms.functional import resize
+ # images_ten = resize(images_ten, size=(self.input_size, self.input_size))
+ # images_ten = normalize(images_ten, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ # out = self.model.image_encoder(images_ten)
+ out = self.model.forward_image(images_ten)
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(out)
+ # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
+ if self.model_name != "facebook/sam2.1-hiera-base-plus/highres":
+ if self.model.directly_add_no_mem_embed:
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+ feats = [
+ feat.permute(1, 2, 0).view(feat.shape[1], -1, *feat_size)
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+ ][::-1]
+ features = feats[-1]
+ # features = self.model.set_image_batch(images)
+ # features = self.model._features['image_embed']
+ B, N, H, W = features.shape
+ return features.permute(0, 2, 3, 1).reshape(B, H * W, N) # (B, grid_size**2, hidden_state_size)
+
+
+@register_backbone("facebook/sam2.1-hiera-base-plus/highres", 32)
+class SAM2HighresFeatures(SAM2Features):
+ model_name = "facebook/sam2.1-hiera-base-plus"
+
+ def __init__(
+ self,
+ image_size: tuple[int, int],
+ save_path: str | Path,
+ batch_size: int = 4,
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ mode: PretrainedFeatsExtractionMode = "nearest_patch",
+ **kwargs,
+ ):
+ super().__init__(image_size, save_path, batch_size, device, mode)
+ self.final_grid_size = 256 # 256x256 grid
+ # self.final_grid_size = 128
+ self.hidden_state_size = 32
+ # self.hidden_state_size = 64
+
+ @property
+ def model_name_path(self):
+ """Returns the model name for saving."""
+ p = f"{self.model_name}/highres"
+ return p.replace("/", "-")
+
+ def _run_model(self, images: list[np.ndarray], **kwargs) -> torch.Tensor:
+ """Extracts embeddings from the model."""
+ with torch.autocast(device_type=self.device), torch.inference_mode():
+ images_ten = torch.stack([torch.tensor(image) for image in images]).to(self.device)
+ # logger.debug(f"Image dtype: {images_ten.dtype}")
+ # logger.debug(f"Image shape: {images_ten.shape}")
+ # logger.debug(f"Image min : {images_ten.min()}, max: {images_ten.max()}")
+ images_ten = F.interpolate(images_ten, size=(self.input_size[0], self.input_size[1]), mode="bilinear", align_corners=False)
+ out = self.model.forward_image(images_ten)
+ backbone_out, _, _, _ = self.model._prepare_backbone_features(out)
+ features = backbone_out["backbone_fpn"][0] # (B, N, H, W)
+
+ B, N, H, W = features.shape
+ return features.permute(0, 2, 3, 1).reshape(B, H * W, N) # (B, grid_size**2, hidden_state_size)
+
+
+@register_backbone("facebookresearch/co-tracker", 128)
+class CoTrackerFeatures(FeatureExtractor):
+ model_name = "facebookresearch/co-tracker"
+
+ def __init__(
+ self,
+ image_size: tuple[int, int],
+ save_path: str | Path,
+ batch_size: int = 4,
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ mode: PretrainedFeatsExtractionMode = "nearest_patch",
+ **kwargs,
+ ):
+ super().__init__(image_size, save_path, batch_size, device, mode)
+ self.input_size = image_size
+ cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline")
+ self.model = cotracker.model
+ self.model.to(device)
+ self.final_grid_size = (image_size[0] // self.model.stride, image_size[1] // self.model.stride)
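+ # the grid follows the CoTracker feature-network stride, so it changes with
+ # the input image size (updated again in precompute_image_embeddings if needed)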
+ self.hidden_state_size = 128
+ self.n_channels = 3
+ self.fmaps_chunk_size = 8
+
+ self.batch_return_type = "list[np.ndarray]"
+
+ def _view_embeddings(self, embeddings, context: dict, **kwargs):
+ image_shape = context.get("masks_shape", None)
+ H, W = image_shape[-2:]
+ grid_size = (H // self.model.stride, W // self.model.stride)
+ embs = embeddings.view(
+ -1,
+ *grid_size,
+ self.hidden_state_size,
+ )
+ return embs, grid_size
+
+ def precompute_image_embeddings(self, images, image_shape=None, **kwargs): # , windows, window_size):
+ """Precomputes embeddings for all images."""
+ try:
+ if image_shape is None:
+ _, H, W = images.shape
+ if H != self.input_size[0] or W != self.input_size[1]:
+ self.input_size = (H, W)
+ self.final_grid_size = (H // self.model.stride, W // self.model.stride)
+ logger.debug(f"Updated CoTracker input size: {self.input_size}, final grid size: {self.final_grid_size}")
+ grid_size = self.final_grid_size
+ else:
+ H, W = image_shape[-2:]
+ grid_size = (H // self.model.stride, W // self.model.stride)
+
+ missing = self._check_missing_embeddings()
+ all_embeddings = torch.zeros(len(images), grid_size[0] * grid_size[1], self.hidden_state_size, device=self.device)
+ if missing:
+ for ts, batches in tqdm(self._prepare_batches(images), total=len(images) // self.batch_size, desc="Computing embeddings", leave=False):
+ embeddings = self._run_model(batches, image_shape, **kwargs)
+ if torch.any(embeddings.isnan()):
+ raise RuntimeError("NaN values found in features.")
+ # logger.debug(f"Embeddings shape: {embeddings.shape}")
+ all_embeddings[ts] = embeddings.to(torch.float32)
+ assert embeddings.shape[-1] == self.hidden_state_size
+ self.embeddings = all_embeddings
+ self._save_features(all_embeddings)
+ # logger.debug(f"Precomputed embeddings shape: {self.embeddings.shape}")
+ return self.embeddings
+ except Exception as e:
+ logger.error(f"Error occurred while precomputing embeddings: {e}")
+ raise e
+
+ def _run_model(self, images: list[np.ndarray], image_shape=None, **kwargs) -> torch.Tensor:
+ self.model.eval()
+ x = torch.stack([torch.tensor(image) for image in images]).to(self.device)
+ x = x.unsqueeze(0) # B, T, C, H, W
+ with torch.no_grad():
+ B = x.shape[0]
+ T = x.shape[1]
+ C_ = x.shape[2]
+ H, W = x.shape[3], x.shape[4]
+ if T > self.batch_size:
+ fmaps = []
+ for t in range(0, T, self.fmaps_chunk_size):
+ video_chunk = x[:, t : t + self.fmaps_chunk_size]
+ fmaps_chunk = self.model.fnet(video_chunk.reshape(-1, C_, H, W))
+ T_chunk = video_chunk.shape[1]
+ C_chunk, H_chunk, W_chunk = fmaps_chunk.shape[1:]
+ fmaps.append(fmaps_chunk.reshape(B, T_chunk, C_chunk, H_chunk, W_chunk))
+ fmaps = torch.cat(fmaps, dim=1).reshape(-1, C_chunk, H_chunk, W_chunk)
+ else:
+ fmaps = self.model.fnet(x.reshape(-1, C_, H, W))
+ fmaps = fmaps.permute(0, 2, 3, 1)
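+ # L2-normalize the feature maps with an epsilon floor, as in the original
+ # CoTracker fnet post-processing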
+ fmaps = fmaps / torch.sqrt(
+ torch.maximum(
+ torch.sum(torch.square(fmaps), axis=-1, keepdims=True),
+ torch.tensor(1e-12, device=fmaps.device),
+ )
+ )
+ fmaps = fmaps.permute(0, 3, 1, 2).reshape(
+ B, -1, self.model.latent_dim, H // self.model.stride, W // self.model.stride
+ ) # B, T, N, H', W'
+ # end of original code
+ fmaps = fmaps.permute(0, 1, 3, 4, 2).squeeze(0) # T, H', W', N
+ fmaps = fmaps.reshape(
+ fmaps.shape[0], fmaps.shape[1] * fmaps.shape[2], fmaps.shape[3]
+ ) # T, H' * W', N
+ return fmaps
+
+
+@register_backbone("microsam/vit_b_lm", 256)
+@register_backbone("microsam/vit_l_lm", 256)
+class MicroSAMFeatures(FeatureExtractor):
+ model_name = "microsam/vit_b_lm"
+
+ def __init__(
+ self,
+ image_size: tuple[int, int],
+ save_path: str | Path,
+ batch_size: int = 4,
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ mode: PretrainedFeatsExtractionMode = "nearest_patch",
+ **kwargs,
+ ):
+ if not MICRO_SAM_AVAILABLE:
+ raise ImportError("microSAM is not available. Please install it following the instructions in the documentation.")
+ super().__init__(image_size, save_path, batch_size, device, mode)
+ self.input_size = 1024
+ self.final_grid_size = 64
+ self.n_channels = 3
+ self.hidden_state_size = 256
+ model_name = self.model_name.split("/")[-1]
+ self.model = get_microsam_model(model_name, device=self.device)
+
+ self.batch_return_type = "list[np.ndarray]"
+ self.channel_first = False
+ self.rescale_batches = False
+
+ def _run_model(self, images: list[np.ndarray], **kwargs) -> torch.Tensor:
+ """Extracts embeddings from the model."""
+ embeddings = []
+ for image in images:
+ # logger.debug(f"Image shape: {image.shape}")
+ self.model.set_image(image)
+ embedding = self.model.get_image_embedding() # (1, hidden_state_size, grid_size, grid_size)
+ B, N, H, W = embedding.shape
+ embedding = embedding.permute(0, 2, 3, 1).reshape(B, H * W, N)
+ embeddings.append(embedding)
+ out = torch.cat(embeddings, dim=0) # (B, grid_size**2, hidden_state_size); keeps the batch dim even for B == 1
+ return out
+
+
+@register_backbone("weigertlab/tarrow", 32)
+class TAPFeatures(FeatureExtractor):
+ model_name = "weigertlab/tarrow"
+
+ def __init__(
+ self,
+ model_folder: str,
+ image_size: tuple[int, int],
+ save_path: str | Path,
+ batch_size: int = 2,
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ mode: PretrainedFeatsExtractionMode = "nearest_patch",
+ **kwargs,
+ ):
+ super().__init__(image_size, save_path, batch_size, device, mode)
+ self._final_grid_size = (self.orig_image_size[0], self.orig_image_size[1])
+ self.input_size = self.final_grid_size
+ self.n_channels = 1
+
+ self.model_folder = model_folder
+ self.full_model, self.hidden_state_size = self._load_model_from_path(model_folder)
+ self.full_model.to(device)
+ self.full_model.eval()
+ AVAILABLE_PRETRAINED_BACKBONES["weigertlab/tarrow"]["feat_dim"] = self.hidden_state_size
+ self.model = self.full_model.backbone
+
+ self.batch_return_type = "torch.Tensor"
+ self.channel_first = False
+ self.rescale_batches = False
+ self.normalize_embeddings = False
+
+ # TODO clear full model from memory
+
+ @property
+ def final_grid_size(self) -> tuple[int, int]:
+ return self._final_grid_size
+
+ @final_grid_size.setter
+ def final_grid_size(self, value: tuple[int, int]):
+ """Sets the final grid size and updates the model's input size."""
+ self._final_grid_size = value
+ self.orig_image_size = value
+ self.input_size = value
+ self.model.input_size = value
+ self.model_patch_size = (1, 1)
+
+ def _set_model_patch_size(self):
+ pass
+
+ @staticmethod
+ def _load_model_from_path(model_folder: str):
+ """Loads the model from the folder."""
+ if not Path(model_folder).exists():
+ raise FileNotFoundError(f"Model folder {model_folder} does not exist.")
+ model = TimeArrowNet.from_folder(model_folder, from_state_dict=True)
+ feat_dim = model.bb_n_feat
+ return model, feat_dim
+
+ def normalize_array(self, b):
+ images_batch = tap_normalize(b)
+ images_batch = torch.from_numpy(images_batch).to(torch.float32) # T, H, W
+ return images_batch
+
+ def _prepare_batches(self, images):
+ if self.do_normalize:
+ images = self.normalize_array(images)
+ images = images.unsqueeze(1) # T, C, H, W
+ return images
+
+ def _run_model(self, images: list[np.ndarray], **kwargs) -> torch.Tensor:
+ """Extracts embeddings from the model."""
+ features = []
+ ts = images.shape[0]
+ im_shape = tuple(images.shape[-2:])
+ self.orig_image_size = im_shape
+ self.final_grid_size = im_shape
+ with torch.no_grad():
+ for i in tqdm(range(0, len(images), self.batch_size), desc="Computing TAP features", leave=False):
+ batch = images[i : i + self.batch_size]
+ batch = batch.to(self.device)
+ out = self.model(batch)
+ features.append(out)
+
+ features = torch.cat(features, dim=0).cpu()
+
+ if self.device == "cuda":
+ torch.cuda.empty_cache()
+
+ features = features.moveaxis(1, 3) # (T, H, W, N)
+ features = features.reshape(ts, self.final_grid_size[0] * self.final_grid_size[1], self.hidden_state_size) # (T, grid_size**2, N)
+ return features
+
+ def precompute_image_embeddings(self, images, **kwargs):
+ # missing = self._check_missing_embeddings()
+ if images.shape[-2:] != self.orig_image_size:
+ self.orig_image_size = images.shape[-2:]
+ self.final_grid_size = images.shape[-2:]
+ batches = self._prepare_batches(images)
+ self.embeddings = self._run_model(batches, **kwargs)
+ # self._save_features(self.embeddings)
+ return self.embeddings
+
+
+@register_backbone("mouseland/cellpose-sam", 192)
+class CellposeSAMFeatures(FeatureExtractor):
+ model_name = "mouseland/cellpose-sam"
+
+ def __init__(
+ self,
+ image_size: tuple[int, int],
+ save_path: str | Path,
+ batch_size: int = 4,
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ mode: PretrainedFeatsExtractionMode = "nearest_patch",
+ **kwargs,
+ ):
+ if not CELLPOSE_AVAILABLE:
+ raise ImportError("Cellpose is not available. Please install it following the instructions in the documentation.")
+ super().__init__(image_size, save_path, batch_size, device, mode)
+ self.input_size = 256
+ self.final_grid_size = 32 # 32x32 grid
+ self.n_channels = 3
+ self.hidden_state_size = 192
+ self.model = CellposeSAM()
+ self.model.to(self.device)
+ self.model.eval()
+
+ self.batch_return_type = "np.ndarray"
+ self.channel_first = False
+ self.rescale_batches = False
+
+ def normalize_array(self, images_batch):
+ batch = torch.zeros(
+ (*tuple(images_batch.shape), self.n_channels), # add a channel dimension
+ dtype=torch.float32
+ )
+ for n, b in enumerate(images_batch):
+ if isinstance(b, torch.Tensor):
+ b = b.cpu().numpy()
+ b_ = cp_transforms.convert_image(
+ b,
+ channel_axis=None,
+ z_axis=None,
+ do_3D=False
+ )
+ b_ = cp_transforms.normalize_img(b_)
+ b_ = torch.from_numpy(b_).to(torch.float32)
+ batch[n] = b_
+ logger.debug(f"Cellpose SAM batch shape: {batch.shape}")
+        return batch[..., 0]  # keep a single channel; _prepare_batches re-expands it to 3
+
+ def _prepare_batches(self, images):
+ if self.do_normalize:
+ images = self.normalize_array(images)
+ for i in range(0, len(images), self.batch_size):
+ end = i + self.batch_size
+ end = min(end, len(images))
+ batch = images[i:end] # (B, H, W)
+ ts = range(i, end)
+ # if self.do_normalize:
+ # batch = self.normalize_array(batch)
+ if len(batch.shape) == 3:
+ batch = batch.unsqueeze(1)
+ batch = batch.repeat(1, 3, 1, 1)
+ batch = batch.to(self.device)
+ yield ts, batch
+
+    def _run_model(self, images_batch: torch.Tensor, **kwargs) -> torch.Tensor:
+ embeddings = []
+
+ with torch.no_grad():
+
+ b = F.interpolate(images_batch, size=(self.input_size[0], self.input_size[1]), mode="bilinear", align_corners=False)
+
+ x = self.model.encoder.patch_embed(b)
+ if self.model.encoder.pos_embed is not None:
+ x = x + self.model.encoder.pos_embed
+            for blk in self.model.encoder.blocks:
+ x = blk(x)
+ x = self.model.encoder.neck(x.permute(0, 3, 1, 2))
+ x = self.model.out(x) # (B, N, H, W)
+ embeddings.append(x)
+
+ embeddings = torch.cat(embeddings, dim=0) # (T, N, H, W)
+ embeddings = embeddings.moveaxis(1, 3) # (T, H, W, N)
+ embeddings = embeddings.reshape(-1, self.final_grid_size[0] * self.final_grid_size[1], self.hidden_state_size) # (T, grid_size**2, N)
+ return embeddings
+
+
+@register_backbone("debug/encoded_labels", 64)
+class EncodedLabelsFeatures(FeatureExtractor):
+ """Encodes labels to 32 dimensions. Should work as a "perfect" feature extractor that uses GT labels as a sanity check."""
+ model_name = "debug/encoded_labels"
+
+ def __init__(
+ self,
+ image_size: tuple[int, int],
+ save_path: str | Path,
+ batch_size: int = 4,
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ mode: PretrainedFeatsExtractionMode = "nearest_patch",
+ **kwargs,
+ ):
+ super().__init__(image_size, save_path, batch_size, device, mode)
+ self.input_size = 1024
+ self.final_grid_size = 32
+ self.n_channels = 1
+ self.hidden_state_size = 64
+
+    def _run_model(self, images, **kwargs) -> None:
+        """No-op: features are derived directly from the labels."""
+        pass
+
+    def precompute_image_embeddings(self, images, **kwargs):
+        """No-op: this debug extractor needs no image embeddings."""
+        pass
+
+ def _encode_labels(self, labels):
+ """Encodes the labels to D dimensions."""
+ features = np.zeros((labels.shape[0], self.hidden_state_size), dtype=np.float32)
+ for i in range(labels.shape[0]):
+ label = labels[i]
+ features[i] = np.array([int(x) for x in np.binary_repr(label, width=self.hidden_state_size)], dtype=np.float32)
+
+ features = torch.from_numpy(features).to(self.device)
+ return features
+
+ def compute_region_features(self, labels=None, **kwargs):
+ return self._encode_labels(labels) # (n_labels, self.hidden_state_size)
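+
+    # Illustrative sketch (hypothetical values): each integer label becomes its
+    # fixed-width binary expansion, so distinct labels map to distinct,
+    # deterministic feature vectors. For example, with width=8:
+    #   np.binary_repr(5, width=8) -> "00000101"
+    #   i.e. the feature row [0., 0., 0., 0., 0., 1., 0., 1.]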
+
+
+@register_backbone("debug/random", 128)
+class RandomFeatures(FeatureExtractor):
+ model_name = "debug/random"
+
+ def __init__(
+ self,
+ image_size: tuple[int, int],
+ save_path: str | Path,
+ batch_size: int = 4,
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ mode: PretrainedFeatsExtractionMode = "nearest_patch",
+ seed: int = 42,
+ **kwargs,
+ ):
+ super().__init__(image_size, save_path, batch_size, device, mode)
+ self.input_size = 1024
+ # self.final_grid_size = self.orig_image_size
+        self.final_grid_size = (128, 128)
+ self.n_channels = 3
+ self.hidden_state_size = 128
+
+ self._seed = seed if seed is not None else 42
+ self.device = "cpu"
+ self._generator = torch.Generator(device="cpu").manual_seed(self._seed)
+
+ self.do_save = False
+
+ def _run_model(self, images, **kwargs) -> torch.Tensor:
+ """Extracts embeddings from the model."""
+ # Normal distribution
+ # return torch.randn(len(images), self.final_grid_size**2, self.hidden_state_size, generator=self._generator).to(self.device)
+ # Uniform distribution
+ feats = torch.rand(
+ len(images),
+ self.final_grid_size[0] * self.final_grid_size[1],
+ self.hidden_state_size,
+ generator=self._generator,
+ dtype=torch.float32
+ )
+ # feats = feats * 4 - 2 # [-2, 2]
+ return feats.to("cpu")
+
+
+FeatureExtractor._available_backbones = AVAILABLE_PRETRAINED_BACKBONES
+
+
+# Embeddings post-processing
+
+import pickle
+
+from sklearn.decomposition import PCA
+
+
+class EmbeddingsPCACompression:
+
+ def __init__(self, original_model_name: str, n_components: int = 15, save_path: str | Path | None = None):
+ self.original_model_name = original_model_name
+ self.n_components = n_components
+        self.save_path = Path(save_path) / "pca_model.pkl" if save_path is not None else None
+ self.pca = PCA(n_components=n_components)
+ self.max_frames = 1500
+ self.random_state = 42
+ self.pca.random_state = self.random_state
+ self.generator = np.random.default_rng(self.random_state)
+
+ @classmethod
+ def from_pretrained_cfg(cls, cfg: PretrainedFeatureExtractorConfig):
+ return cls(
+ original_model_name=cfg.model_name.replace("/", "-"),
+ n_components=cfg.pca_components,
+ save_path=cfg.pca_preprocessor_path
+ )
+
+ def fit(self, X: np.ndarray):
+ """Fits the PCA model to the embeddings."""
+ self.pca.fit(X)
+
+ if self.save_path is not None:
+            self.save_path.parent.mkdir(parents=True, exist_ok=True)
+ with open(self.save_path, 'wb') as f:
+ pickle.dump(self.pca, f)
+
+ def transform(self, X: np.ndarray) -> np.ndarray:
+ """Transforms the embeddings using the fitted PCA model."""
+ return self.pca.transform(X)
+
+ def load_from_file(self, path: str | Path):
+ """Loads the PCA model from a file."""
+ if isinstance(path, str):
+ path = Path(path)
+ path = path / "pca_model.pkl" if path.suffix != ".pkl" else path
+ with open(path, 'rb') as f:
+ self.pca = pickle.load(f)
+ logger.info(f"Loaded PCA model from {path}.")
+
+ def fit_on_embeddings(self, embeddings_source_folders: list[str | Path]):
+ """Fits the PCA model to the embeddings loaded from a file."""
+ embeddings = []
+        n_samples = 0
+
+ embeddings_paths = []
+ for folder in embeddings_source_folders:
+ for file in Path(folder).rglob("*.npy"):
+ if self.original_model_name in file.name:
+ embeddings_paths.append(file)
+
+ if len(embeddings_paths) == 0:
+ return
+
+ embeddings_paths = self.generator.permutation(embeddings_paths)
+ logger.info(f"Fitting PCA model on {len(embeddings_paths)} embeddings files.")
+ logger.info("Files :")
+ for p in embeddings_paths:
+ logger.info(f" - {p}")
+ logger.info("*" * 50)
+
+ for path in embeddings_paths:
+ if isinstance(path, str):
+ path = Path(path)
+ if not path.exists():
+ raise FileNotFoundError(f"File {path} does not exist.")
+            emb = np.load(path)
+            embeddings.append(emb)
+            n_samples += emb.shape[0]
+            if n_samples > self.max_frames:
+                logger.info(f"Loaded frames exceed the {self.max_frames} limit for PCA computation, stopping.")
+                break
+
+ embeddings = np.concatenate(embeddings, axis=0)
+ embeddings = embeddings.reshape(-1, embeddings.shape[-1])
+ self.fit(embeddings)
+ logger.info(f"Fitted PCA model with {self.n_components} components on {N_samples} frames.")
+
\ No newline at end of file
diff --git a/trackastra/data/utils.py b/trackastra/data/utils.py
index a451876..5a37e78 100644
--- a/trackastra/data/utils.py
+++ b/trackastra/data/utils.py
@@ -1,6 +1,6 @@
import logging
import sys
-from pathlib import Path
+from pathlib import Path, WindowsPath
import numpy as np
import pandas as pd
@@ -9,9 +9,25 @@
import tifffile
from tqdm import tqdm
+from trackastra.data.pretrained_features import PretrainedFeatureExtractorConfig
+
logger = logging.getLogger(__name__)
+def make_hashable(obj):
+ if isinstance(obj, tuple | list):
+ return tuple(make_hashable(e) for e in obj)
+ elif isinstance(obj, Path | WindowsPath):
+ return obj.as_posix()
+ elif isinstance(obj, dict):
+ return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
+ elif isinstance(obj, PretrainedFeatureExtractorConfig):
+ cfg_dict = obj.to_dict()
+ return make_hashable(cfg_dict)
+ else:
+ return obj
+
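+# Illustrative example: nested containers collapse into nested tuples with a
+# deterministic key order, so the result can serve as a cache key:
+#   make_hashable({"b": [1, 2], "a": Path("/tmp/x")})
+#   -> (("a", "/tmp/x"), ("b", (1, 2)))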
+
def load_tiff_timeseries(
dir: Path,
dtype: str | type | None = None,
diff --git a/trackastra/data/wrfeat.py b/trackastra/data/wrfeat.py
index 2c72af7..3c57a5f 100644
--- a/trackastra/data/wrfeat.py
+++ b/trackastra/data/wrfeat.py
@@ -2,12 +2,15 @@
WindowedRegionFeatures (WRFeatures) is a class that holds regionprops features for a windowed track region.
"""
+from __future__ import annotations
+
import itertools
import logging
+from abc import ABC, abstractmethod
from collections import OrderedDict
from collections.abc import Iterable, Sequence
from functools import reduce
-from typing import Literal
+from typing import TYPE_CHECKING, ClassVar, Literal
import joblib
import numpy as np
@@ -16,9 +19,12 @@
from skimage.measure import regionprops, regionprops_table
from tqdm import tqdm
-from trackastra.data.utils import load_tiff_timeseries
+if TYPE_CHECKING:
+
+ from trackastra.data.pretrained_features import FeatureExtractor
logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
_PROPERTIES = {
"regionprops": (
@@ -34,7 +40,12 @@
"inertia_tensor",
"border_dist",
),
+ "regionprops_small": (
+ "area",
+ "inertia_tensor",
+ ),
}
+DEFAULT_PROPERTIES = "regionprops2"
def _filter_points(
@@ -67,8 +78,56 @@ def _border_dist(mask: np.ndarray, cutoff: float = 5):
return tuple(r.intensity_max for r in regionprops(mask, intensity_image=dist))
+def _border_dist_fast(mask: np.ndarray, cutoff: float = 5):
+ cutoff = int(cutoff)
+ border = np.ones(mask.shape, dtype=np.float32)
+ ndim = len(mask.shape)
+
+ for axis, size in enumerate(mask.shape):
+ # Create fade values for the band [0, cutoff)
+ band_vals = np.arange(cutoff, dtype=np.float32) / cutoff
+
+ # Build slices for the low border
+ low_slices = [slice(None)] * ndim
+ low_slices[axis] = slice(0, cutoff)
+ border_low = border[tuple(low_slices)]
+ border_low_vals = np.minimum(
+ border_low, band_vals[(...,) + (None,) * (ndim - axis - 1)]
+ )
+ border[tuple(low_slices)] = border_low_vals
+
+ # Build slices for the high border
+ high_slices = [slice(None)] * ndim
+ high_slices[axis] = slice(size - cutoff, size)
+ band_vals_rev = band_vals[::-1]
+ border_high = border[tuple(high_slices)]
+ border_high_vals = np.minimum(
+ border_high, band_vals_rev[(...,) + (None,) * (ndim - axis - 1)]
+ )
+ border[tuple(high_slices)] = border_high_vals
+
+ dist = 1 - border
+ return tuple(r.intensity_max for r in regionprops(mask, intensity_image=dist))
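+
+# Quick illustration: `dist` ramps linearly from 1 at the image border down to
+# 0 at `cutoff` pixels inward, so each region's value is its closeness to the
+# border, e.g. (hypothetical mask):
+#   mask = np.zeros((20, 20), np.int32)
+#   mask[0:3, 0:3] = 1      # touches the border
+#   mask[9:12, 9:12] = 2    # more than cutoff=5 pixels from every border
+#   _border_dist_fast(mask, cutoff=5)  # -> (1.0, 0.0)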
+
+# Features classes
+
+
class WRFeatures:
"""regionprops features for a windowed track region."""
+ PROPERTIES_DIMS: ClassVar = {
+ "regionprops": {
+ 2: 8,
+ 3: 12,
+ },
+ "regionprops2": {
+ 2: 7,
+ 3: 12,
+ },
+ "regionprops_small": {
+ 2: 5,
+ 3: 9,
+ },
+ }
def __init__(
self,
@@ -76,6 +135,7 @@ def __init__(
labels: np.ndarray,
timepoints: np.ndarray,
features: OrderedDict[np.ndarray],
+ properties: str = DEFAULT_PROPERTIES,
):
self.ndim = coords.shape[-1]
if self.ndim not in (2, 3):
@@ -83,9 +143,14 @@ def __init__(
self.coords = coords
self.labels = labels
- self.features = features.copy()
+ if features is None:
+ self.features = OrderedDict()
+ else:
+ self.features = features.copy()
self.timepoints = timepoints
-
+
+ self.properties = properties
+
def __repr__(self):
s = (
f"WindowRegionFeatures(ndim={self.ndim}, nregions={len(self.labels)},"
@@ -97,7 +162,25 @@ def __repr__(self):
@property
def features_stacked(self):
- return np.concatenate([v for k, v in self.features.items()], axis=-1)
+ if not self.features or (len(self.features) == 1 and "pretrained_feats" in self.features):
+ # logger.warning("No features to stack")
+ return None
+ feats = np.concatenate(
+ [v for k, v in self.features.items() if k != "pretrained_feats"],
+ axis=-1
+ )
+ return feats
+
+ @property
+ def pretrained_feats(self):
+ # for compatibility with WRPretrainedFeatures
+ if "pretrained_feats" in self.features:
+ return self.features["pretrained_feats"]
+ # return self.features["pretrained_feats"] / np.linalg.norm(
+ # self.features["pretrained_feats"], axis=-1, keepdims=True
+ # )
+ return None
def __len__(self):
return len(self.labels)
@@ -107,15 +190,22 @@ def __getitem__(self, key):
return self.features[key]
else:
raise KeyError(f"Key {key} not found in features")
+
+ @property
+ def features_dims(self):
+ """Returns the number of features for each property."""
+ if self.properties not in self.PROPERTIES_DIMS:
+ raise ValueError(f"Unknown feature type {self.properties}")
+ return self.PROPERTIES_DIMS[self.properties][self.ndim]
@classmethod
- def concat(cls, feats: Sequence["WRFeatures"]) -> "WRFeatures":
+ def concat(cls, feats: Sequence[WRFeatures]) -> WRFeatures:
"""Concatenate multiple WRFeatures into a single one."""
if len(feats) == 0:
raise ValueError("Cannot concatenate empty list of features")
return reduce(lambda x, y: x + y, feats)
- def __add__(self, other: "WRFeatures") -> "WRFeatures":
+ def __add__(self, other: WRFeatures) -> WRFeatures:
"""Concatenate two WRFeatures."""
if self.ndim != other.ndim:
raise ValueError("Cannot concatenate features of different dimensions")
@@ -131,23 +221,24 @@ def __add__(self, other: "WRFeatures") -> "WRFeatures":
for k, v in self.features.items()
)
- return WRFeatures(
+ return self.__class__(
coords=coords, labels=labels, timepoints=timepoints, features=features
)
-
- @classmethod
- def from_mask_img(
- cls,
- mask: np.ndarray,
- img: np.ndarray,
- properties="regionprops2",
- t_start: int = 0,
- ):
+
+ @staticmethod
+    def get_regionprops_features(properties: str | None, mask: np.ndarray, img: np.ndarray, t_start: int = 0):
+ """Extracts regionprops features from a mask and image."""
+ img = np.asarray(img)
+ mask = np.asarray(mask)
_ntime, ndim = mask.shape[0], mask.ndim - 1
if ndim not in (2, 3):
raise ValueError("Only 2D or 3D data is supported")
- properties = tuple(_PROPERTIES[properties])
+ if properties is None:
+ properties = ()
+ else:
+ properties = tuple(_PROPERTIES[properties])
+
if "label" in properties or "centroid" in properties:
raise ValueError(
f"label and centroid should not be in properties {properties}"
@@ -169,7 +260,7 @@ def from_mask_img(
_df["timepoint"] = i + t_start
if use_border_dist:
- _df["border_dist"] = _border_dist(y)
+ _df["border_dist"] = _border_dist_fast(y)
dfs.append(_df)
df = pd.concat(dfs)
@@ -180,6 +271,24 @@ def from_mask_img(
timepoints = df["timepoint"].values.astype(np.int32)
labels = df["label"].values.astype(np.int32)
coords = df[[f"centroid-{i}" for i in range(ndim)]].values.astype(np.float32)
+
+ # if any NaNs in features, raise
+ if df.isnull().values.any():
+ raise ValueError("NaNs found in features DataFrame")
+
+ return df, coords, labels, timepoints, properties
+
+ @classmethod
+ def from_mask_img(
+ cls,
+ mask: np.ndarray,
+ img: np.ndarray,
+ properties: str = DEFAULT_PROPERTIES,
+ t_start: int = 0,
+ ):
+        # keep the property-set name: get_regionprops_features returns the
+        # expanded tuple of property names, but __init__ expects the name
+        properties_name = properties
+        df, coords, labels, timepoints, properties = cls.get_regionprops_features(
+            properties, mask, img, t_start=t_start
+        )
features = OrderedDict(
(
@@ -197,26 +306,179 @@ def from_mask_img(
)
return cls(
- coords=coords, labels=labels, timepoints=timepoints, features=features
+            coords=coords, labels=labels, timepoints=timepoints, features=features, properties=properties_name
)
-# augmentations
+class WRPretrainedFeatures(WRFeatures):
+ """WindowedRegion with features from pre-trained models."""
+
+ def __init__(
+ self,
+ coords: np.ndarray,
+ labels: np.ndarray,
+ timepoints: np.ndarray,
+        features: OrderedDict[str, np.ndarray],
+ additional_properties: str | None = None
+ ):
+ super().__init__(coords, labels, timepoints, features)
+ self.additional_properties = additional_properties
+
+    # `features_stacked` and `pretrained_feats` are inherited unchanged from WRFeatures.
+
+ @classmethod
+ def from_mask_img(
+ cls,
+ img: np.ndarray,
+ mask: np.ndarray,
+ feature_extractor: FeatureExtractor,
+ t_start: int = 0,
+ additional_properties: str | None = None,
+ # embeddings: torch.Tensor | None = None,
+ ) -> WRPretrainedFeatures:
+
+ ndim = img.ndim - 1
+ if ndim != 2:
+ raise ValueError("Only 2D data is supported")
+
+ df, coords, labels, timepoints, properties = cls.get_regionprops_features(
+ additional_properties, mask, img, t_start=t_start
+ )
+ # if embeddings is None:
+ _, features = feature_extractor.extract_embedding(mask, timepoints, labels, coords)
+ # else:
+ # _, features = feature_extractor.extract_embedding(mask, timepoints, labels, coords, embs=embeddings)
+ features = features.detach().cpu().numpy()
+ feats_dict = OrderedDict(pretrained_feats=features)
+ # Add additional features similarly to WRFeatures if any
+ if additional_properties is not None:
+ for p in properties:
+ feats_dict[p] = np.stack(
+ [
+ df[c].values.astype(np.float32)
+ for c in df.columns
+ if c.startswith(p)
+ ],
+ axis=-1,
+ )
+
+ return cls(
+ coords=coords, labels=labels, timepoints=timepoints, features=feats_dict, additional_properties=additional_properties
+ )
+
+
+class WRAugPretrainedFeatures(WRPretrainedFeatures):
+
+ def __init__(
+ self,
+ coords: np.ndarray,
+ labels: np.ndarray,
+ timepoints: np.ndarray,
+        features: OrderedDict[str, np.ndarray],
+ additional_properties: str | None = None,
+ ):
+
+ super().__init__(coords, labels, timepoints, features, additional_properties)
+ self.ndim = coords.shape[-1]
+ if self.ndim != 2:
+ raise ValueError("Only 2D data is supported")
+
+ @classmethod
+ def from_window(cls, features, coords, timepoints, labels):
+ """Build a WRAugPretrainedFeatures from a window.
+
+ Args:
+            features (OrderedDict[str, np.ndarray]): The features to use.
+ coords (np.ndarray): The coordinates to use.
+ timepoints (np.ndarray): The timepoints to use.
+ labels (np.ndarray): The labels to use.
+ """
+ # coords = coords[:, 1:]
+ return cls(
+ coords=coords,
+ labels=labels,
+ timepoints=timepoints,
+ features=features,
+ )
+
+ def to_window(self):
+ """Convert the features to a window."""
+ coords = np.concatenate((self.timepoints[:, None], self.coords), axis=-1)
+
+        if len(self.features) == 1 and "pretrained_feats" in self.features:
+ feats = None
+ else:
+ feats = np.concatenate(
+ [v for k, v in self.features.items() if k != "pretrained_feats"],
+ axis=-1
+ )
+ pretrained_feats = self.features["pretrained_feats"]
+ return feats, pretrained_feats, coords, self.timepoints, self.labels
+
+ def to_dict(self):
+ """Convert the features to a dictionary."""
+ res = {}
+ for i, (t, lab) in enumerate(zip(self.timepoints, self.labels)):
+ t = int(t)
+ lab = int(lab)
+ if t not in res:
+ res[t] = {}
+ feat_dict = {}
+ for k, v in self.features.items():
+ feat_dict[k] = v[i]
+ res[t][lab] = {
+ "coords": self.coords[i],
+ "features": dict(feat_dict),
+ }
+ return res
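+
+    # The returned mapping is nested by timepoint, then label, e.g.:
+    #   {0: {17: {"coords": array([y, x]),
+    #             "features": {"pretrained_feats": array([...]), ...}}}}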
+
+
+# Augmentations
class WRRandomCrop:
- """windowed region random crop augmentation."""
+ """windowed region random crop augmentation.
+
+ Affected properties:
+ - "coords"
+ - "labels"
+ - "timepoints"
+ - "features"
+ """
+ return_type = WRFeatures
def __init__(
self,
crop_size: int | tuple[int] | None = None,
ndim: int = 2,
+ return_type=WRFeatures,
) -> None:
"""crop_size: tuple of int
can be tuple of length 1 (all dimensions)
of length ndim (y,x,...)
of length 2*ndim (y1,y2, x1,x2, ...).
"""
+ self.return_type = return_type
if isinstance(crop_size, int):
crop_size = (crop_size,) * 2 * ndim
elif isinstance(crop_size, Iterable):
@@ -258,45 +520,84 @@ def __call__(self, features: WRFeatures):
)
idx = _filter_points(points, shape=crop_size, origin=corner)
-
+ feats = OrderedDict(
+ (k, v[idx]) for k, v in features.features.items()
+ )
return (
- WRFeatures(
+ self.return_type(
coords=points[idx],
labels=features.labels[idx],
timepoints=features.timepoints[idx],
- features=OrderedDict((k, v[idx]) for k, v in features.features.items()),
+ features=feats,
),
idx,
)
-class WRBaseAugmentation:
+class WRBaseAugmentation(ABC):
+ """Base class for windowed region augmentations."""
+ return_type = WRFeatures
+
def __init__(self, p: float = 0.5) -> None:
self._p = p
self._rng = np.random.RandomState()
def __call__(self, features: WRFeatures):
+ # logger.debug(f"Before augmentation: {self.__class__.__name__}, return_type={self.return_type}")
if self._rng.rand() > self._p or len(features) == 0:
return features
- return self._augment(features)
+ feats = self._augment(features)
+ # logger.debug(f"After augmentation: {self.__class__.__name__}, return_type={self.return_type}")
+ self.check_features(feats)
+ return feats
+ @abstractmethod
def _augment(self, features: WRFeatures):
raise NotImplementedError()
+
+ def check_features(self, features: WRFeatures):
+ """Check if features are valid."""
+ if not isinstance(features, self.return_type):
+ raise ValueError(f"Expected {self.return_type}, got {type(features)}")
+ if len(features) == 0:
+ raise ValueError("Empty features")
+
+ for k, f in features.features.items():
+ if np.any(np.isnan(f)):
+ logger.warning(f"NaNs found in {k} after {self.__class__.__name__} augmentation")
+ if np.any(np.isinf(f)):
+ logger.warning(f"Infs found in {k} after {self.__class__.__name__} augmentation")
+ # if np.all(np.all(f == 0, axis=-1)):
+ # logger.warning(f"Empty {k} after {self.__class__.__name__} augmentation")
class WRRandomFlip(WRBaseAugmentation):
+ """Random flip augmentation.
+
+ Affected properties:
+ - "area"
+ - "equivalent_diameter_area"
+ - "inertia_tensor"
+ """
def _augment(self, features: WRFeatures):
ndim = features.ndim
flip = self._rng.randint(0, 2, features.ndim)
+ M = np.eye(ndim)
points = features.coords.copy()
for i, f in enumerate(flip):
if f == 1:
points[:, ndim - i - 1] *= -1
- return WRFeatures(
+                M[ndim - i - 1, ndim - i - 1] = -1  # flip the same axis as the coordinates
+
+ feats = OrderedDict(
+ (k, _transform_affine(k, v, M)) for k, v in features.features.items()
+ )
+
+ return self.return_type(
coords=points,
labels=features.labels,
timepoints=features.timepoints,
- features=features.features,
+ features=feats,
)
@@ -323,18 +624,22 @@ def _rotation_matrix(theta: float):
def _transform_affine(k: str, v: np.ndarray, M: np.ndarray):
ndim = len(M)
+
if k == "area":
- v = np.linalg.det(M) * v
+ v = np.abs(np.linalg.det(M)) * v
elif k == "equivalent_diameter_area":
- v = np.linalg.det(M) ** (1 / len(M)) * v
-
+ # v = np.linalg.det(M) ** (1 / len(M)) * v
+ v = np.abs(np.linalg.det(M)) ** (1 / len(M)) * v
+ # TODO check the behavior of equivalent_diameter_area in 3D regionprops
elif k == "inertia_tensor":
# v' = M * v * M^T
- v = v.reshape(-1, ndim, ndim)
+ v = v.reshape(-1, ndim, ndim)
+
# v * M^T
v = np.einsum("ijk, mk -> ijm", v, M)
# M * v
v = np.einsum("ij, kjm -> kim", M, v)
+
v = v.reshape(-1, ndim * ndim)
elif k in (
"intensity_mean",
@@ -344,28 +649,47 @@ def _transform_affine(k: str, v: np.ndarray, M: np.ndarray):
"border_dist",
):
pass
+ elif k == "pretrained_feats":
+ pass
else:
raise ValueError(f"Don't know how to affinely transform {k}")
+
+ if np.isnan(v).any():
+ logger.error(f"NaNs found in {k} after affine transformation")
+
return v
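+
+# Worked example (illustrative): for a flip about one axis, M = diag(-1, 1), so
+# |det(M)| = 1 and "area" is unchanged, while the inertia tensor v' = M v M^T
+# flips the sign of its off-diagonal entries:
+#   v = [[a, b], [b, c]]  ->  v' = [[a, -b], [-b, c]]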
class WRRandomAffine(WRBaseAugmentation):
+ """Random affine transformation augmentation.
+
+ Affected properties:
+ - "area"
+ - "equivalent_diameter_area"
+ - "inertia_tensor"
+ - "coords"
+ """
def __init__(
self,
degrees: float = 10,
scale: float = (0.9, 1.1),
shear: float = (0.1, 0.1),
p: float = 0.5,
+        scale_isotropic: tuple[float, float] = (1.0, 1.0),
):
super().__init__(p)
self.degrees = degrees if degrees is not None else 0
self.scale = scale if scale is not None else (1, 1)
self.shear = shear if shear is not None else (0, 0)
-
+ self.scale_isotropic = scale_isotropic
+
def _augment(self, features: WRFeatures):
degrees = self._rng.uniform(-self.degrees, self.degrees) / 180 * np.pi
+ scale_iso = self._rng.uniform(*self.scale_isotropic)
scale = self._rng.uniform(*self.scale, 3)
+ scale = scale * scale_iso
+
shy = self._rng.uniform(-self.shear[0], self.shear[0])
shx = self._rng.uniform(-self.shear[1], self.shear[1])
@@ -383,7 +707,7 @@ def _augment(self, features: WRFeatures):
(k, _transform_affine(k, v, self._M)) for k, v in features.features.items()
)
- return WRFeatures(
+ return self.return_type(
coords=points,
labels=features.labels,
timepoints=features.timepoints,
@@ -392,6 +716,11 @@ def _augment(self, features: WRFeatures):
class WRRandomBrightness(WRBaseAugmentation):
+ """random brightness augmentation.
+
+ Affected properties:
+ - "intensity"
+ """
def __init__(
self,
scale: tuple[float] = (0.5, 2.0),
@@ -414,7 +743,7 @@ def _augment(self, features: WRFeatures):
v = v * scale + shift
key_vals.append((k, v))
feats = OrderedDict(key_vals)
- return WRFeatures(
+ return self.return_type(
coords=features.coords,
labels=features.labels,
timepoints=features.timepoints,
@@ -423,6 +752,11 @@ def _augment(self, features: WRFeatures):
class WRRandomOffset(WRBaseAugmentation):
+ """Random offset augmentation.
+
+ Affected properties:
+ - "coords"
+ """
def __init__(self, offset: float = (-3, 3), p: float = 0.5):
super().__init__(p)
self.offset = offset
@@ -431,7 +765,8 @@ def _augment(self, features: WRFeatures):
offset = self._rng.uniform(*self.offset, features.coords.shape)
coords = features.coords + offset
- return WRFeatures(
+
+ return self.return_type(
coords=coords,
labels=features.labels,
timepoints=features.timepoints,
@@ -440,7 +775,11 @@ def _augment(self, features: WRFeatures):
class WRRandomMovement(WRBaseAugmentation):
- """random global linear shift."""
+ """Random global linear shift.
+
+ Affected properties:
+ - "coords"
+ """
def __init__(self, offset: float = (-10, 10), p: float = 0.5):
super().__init__(p)
self.offset = offset
@@ -451,7 +790,7 @@ def _augment(self, features: WRFeatures):
offset = (features.timepoints[:, None] - tmin) * base_offset[None]
coords = features.coords + offset
- return WRFeatures(
+ return self.return_type(
coords=coords,
labels=features.labels,
timepoints=features.timepoints,
@@ -460,53 +799,171 @@ def _augment(self, features: WRFeatures):
class WRAugmentationPipeline:
- def __init__(self, augmentations: Sequence[WRBaseAugmentation]):
+ def __init__(self, augmentations: Sequence[WRBaseAugmentation], return_type=None):
self.augmentations = augmentations
+ self.return_type = return_type if return_type is not None else WRFeatures
+ logger.debug(f"Augmentation pipeline return type: {self.return_type}")
+ for aug in self.augmentations:
+ aug.return_type = self.return_type
def __call__(self, feats: WRFeatures):
+ # logger.debug(f"Applying {len(self.augmentations)} augmentations")
for aug in self.augmentations:
+ # logger.debug(f"Applying {aug.__class__.__name__} augmentation")
+ aug.return_type = self.return_type
+ # logger.debug(f"Augmentation return type: {aug.return_type}")
feats = aug(feats)
+
return feats
+# Factory functions
+
+
+class AugmentationFactory:
+ default_args: ClassVar = {
+ "flip": {"p": 0.5},
+ "affine": {
+ "p": 0.8,
+ "degrees": 180,
+ "scale": (0.5, 2),
+ "shear": (0.1, 0.1),
+ },
+ "brightness": {"p": 0.8},
+ "offset": {"p": 0.8, "offset": (-3, 3)},
+ "movement": {"offset": (-10, 10), "p": 0.3},
+ }
+
+ @staticmethod
+    def create_augmentation_pipeline(augment: int, return_type: type = WRFeatures):
+ if augment == 0:
+ return None
+ elif augment == 1:
+ return WRAugmentationPipeline(
+ [
+ WRRandomFlip(**AugmentationFactory.default_args["flip"]),
+ WRRandomAffine(
+ **AugmentationFactory.default_args["affine"]
+ ),
+ # wrfeat.WRRandomBrightness(p=0.8, factor=(0.5, 2.0)),
+ # wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
+ ],
+ return_type=return_type,
+ )
+ elif augment == 2:
+ return WRAugmentationPipeline(
+ [
+ WRRandomFlip(**AugmentationFactory.default_args["flip"]),
+ WRRandomAffine(**AugmentationFactory.default_args["affine"]),
+ WRRandomBrightness(**AugmentationFactory.default_args["brightness"]),
+ WRRandomOffset(**AugmentationFactory.default_args["offset"]),
+ ],
+ return_type=return_type,
+ )
+ elif augment == 3:
+ return WRAugmentationPipeline(
+ [
+ WRRandomFlip(**AugmentationFactory.default_args["flip"]),
+ WRRandomAffine(
+ p=0.8,
+ degrees=180,
+ scale=(0.9, 1.1),
+ shear=(0.1, 0.1),
+ scale_isotropic=(0.5, 2.0)
+ ),
+ WRRandomBrightness(**AugmentationFactory.default_args["brightness"]),
+ WRRandomMovement(**AugmentationFactory.default_args["movement"]),
+ WRRandomOffset(**AugmentationFactory.default_args["offset"]),
+ ],
+ return_type=return_type,
+ )
+ elif augment == 4:
+ return WRAugmentationPipeline(
+ [
+ WRRandomAffine(
+ p=0.8,
+ degrees=180,
+ scale=(0.9, 1.1),
+ shear=(0.1, 0.1),
+ scale_isotropic=(0.5, 2.0)
+ ),
+ # WRRandomMovement(**AugmentationFactory.default_args["movement"]),
+ # WRRandomOffset(**AugmentationFactory.default_args["offset"]),
+ ],
+ return_type=return_type,
+ )
+ else:
+ raise ValueError(f"Invalid augment level {augment}")
+
+ @staticmethod
+    def create_cropper(crop_size: tuple[int, ...] | None, ndim: int, return_type=WRFeatures):
+ if crop_size is not None:
+ return WRRandomCrop(
+ crop_size=crop_size,
+ ndim=ndim,
+ return_type=return_type,
+ )
+ return None
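+
+# Example usage (illustrative sketch; `feats` is an existing WRFeatures):
+#
+#   pipeline = AugmentationFactory.create_augmentation_pipeline(2, return_type=WRFeatures)
+#   cropper = AugmentationFactory.create_cropper(crop_size=256, ndim=2)
+#   feats_aug = pipeline(feats)
+#   feats_crop, idx = cropper(feats_aug)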
+
def get_features(
detections: np.ndarray,
imgs: np.ndarray | None = None,
- features: Literal["none", "wrfeat"] = "wrfeat",
+ features_type: Literal["none", "wrfeat", "pretrained_feats", "pretrained_feats_aug"] = "wrfeat",
ndim: int = 2,
n_workers=0,
progbar_class=tqdm,
+ feature_extractor: FeatureExtractor | None = None,
) -> list[WRFeatures]:
+ """Extracts features from detections and images."""
detections = _check_dimensions(detections, ndim)
imgs = _check_dimensions(imgs, ndim)
logger.info(f"Extracting features from {len(detections)} detections")
- if n_workers > 0:
- features = joblib.Parallel(n_jobs=n_workers)(
- joblib.delayed(WRFeatures.from_mask_img)(
- # New axis for time component
- mask=mask[np.newaxis, ...],
- img=img[np.newaxis, ...],
- t_start=t,
+ if features_type in ["none", "wrfeat"]:
+ if n_workers > 0:
+ logger.info(f"Using {n_workers} processes for feature extraction")
+ features = joblib.Parallel(n_jobs=n_workers, backend="loky")(
+ joblib.delayed(WRFeatures.from_mask_img)(
+ # New axis for time component
+ mask=mask[np.newaxis, ...].copy(),
+ img=img[np.newaxis, ...].copy(),
+ t_start=t,
+ )
+ for t, (mask, img) in progbar_class(
+ enumerate(zip(detections, imgs)),
+ total=len(imgs),
+ desc="Extracting features",
+ )
)
- for t, (mask, img) in progbar_class(
- enumerate(zip(detections, imgs)),
- total=len(imgs),
- desc="Extracting features",
+ else:
+ logger.info("Using single process for feature extraction")
+ features = tuple(
+ WRFeatures.from_mask_img(
+ mask=mask[np.newaxis, ...],
+ img=img[np.newaxis, ...],
+ t_start=t,
+ )
+ for t, (mask, img) in progbar_class(
+ enumerate(zip(detections, imgs)),
+ total=len(imgs),
+ desc="Extracting features",
+ )
)
- )
+ if features_type == "none":
+ for f in features:
+ f.features = OrderedDict()
+ elif features_type == "pretrained_feats" or features_type == "pretrained_feats_aug":
+ feature_extractor.precompute_image_embeddings(imgs)
+ features = [
+ WRPretrainedFeatures.from_mask_img(
+            img=img[np.newaxis],
+            mask=mask[np.newaxis],
+            feature_extractor=feature_extractor,
+            t_start=t,
+            additional_properties=feature_extractor.additional_features,
+ )
+ for t, (mask, img) in enumerate(zip(detections, imgs))
+ ]
else:
- logger.info("Using single process for feature extraction")
- features = tuple(
- WRFeatures.from_mask_img(
- mask=mask[np.newaxis, ...],
- img=img[np.newaxis, ...],
- t_start=t,
- )
- for t, (mask, img) in progbar_class(
- enumerate(zip(detections, imgs)),
- total=len(imgs),
- desc="Extracting features",
- )
+ raise ValueError(
+ f"Unknown feature extraction method {features}. Available: 'none', 'wrfeat' or 'pretrained_feats'."
)
return features
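+
+# Example (illustrative): classical regionprops features for matching
+# (T, Y, X) numpy arrays `masks` and `imgs`:
+#
+#   feats = get_features(detections=masks, imgs=imgs, features_type="wrfeat", ndim=2)
+#   windows = build_windows(feats, window_size=4)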
@@ -529,6 +986,7 @@ def _check_dimensions(x: np.ndarray, ndim: int):
def build_windows(
features: list[WRFeatures], window_size: int, progbar_class=tqdm
) -> list[dict]:
+ """Builds windows from a list of WRFeatures."""
windows = []
for t1, t2 in progbar_class(
zip(range(0, len(features)), range(window_size, len(features) + 1)),
@@ -543,13 +1001,14 @@ def build_windows(
if len(feat) == 0:
coords = np.zeros((0, feat.ndim), dtype=int)
-
+    pt_feats = feat.pretrained_feats
w = dict(
coords=coords,
t1=t1,
labels=labels,
timepoints=timepoints,
features=feat.features_stacked,
+ pretrained_features=pt_feats,
)
windows.append(w)
@@ -558,23 +1017,32 @@ def build_windows(
if __name__ == "__main__":
- imgs = load_tiff_timeseries(
- # "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01",
- "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01",
- )
- masks = load_tiff_timeseries(
- # "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01_GT/TRA",
- "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01_GT/TRA",
- dtype=int,
- )
-
- features = get_features(detections=masks, imgs=imgs, ndim=3)
- windows = build_windows(features, window_size=4)
-
-
-# if __name__ == "__main__":
-# y = np.zeros((1, 100, 100), np.uint8)
-# y[:, 20:40, 20:60] = 1
-# x = y + np.random.normal(0, 0.1, y.shape)
-
-# f = WRFeatures.from_mask_img(y, x, properties=("intensity_mean", "area"))
+ # imgs = load_tiff_timeseries(
+ # # "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01",
+ # "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01",
+ # )
+ # masks = load_tiff_timeseries(
+ # # "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01_GT/TRA",
+ # "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01_GT/TRA",
+ # dtype=int,
+ # )
+
+ # features = get_features(detections=masks, imgs=imgs, ndim=3)
+ # windows = build_windows(features, window_size=4)
+
+ y = np.zeros((1, 100, 100), np.uint8)
+ y[:, 20:40, 20:60] = 1
+ x = y + np.random.normal(0, 0.1, y.shape)
+
+ f = WRFeatures.from_mask_img(y, x, properties='regionprops2')
+
+ # f = WRFeatures.from_pretrained(y, x)
+
+ augmenter = WRAugmentationPipeline([
+ WRRandomAffine(degrees=10, scale=(0.9, 1.1), shear=(0.1, 0.1), p=0.5),
+ WRRandomBrightness(scale=(0.5, 2.0), shift=(-0.1, 0.1), p=0.5),
+ WRRandomOffset(offset=(-3, 3), p=0.5),
+ WRRandomMovement(offset=(-10, 10), p=0.5),
+ ])
+
+ f2 = augmenter(f)
\ No newline at end of file
diff --git a/trackastra/model/model.py b/trackastra/model/model.py
index 3b2bb57..2c07b8f 100644
--- a/trackastra/model/model.py
+++ b/trackastra/model/model.py
@@ -32,7 +32,7 @@ def __init__(
dropout=0.1,
cutoff_spatial: int = 256,
window: int = 16,
- positional_bias: Literal["bias", "rope", "none"] = "bias",
+ positional_bias: Literal["bias", "rope", "none"] = "rope",
positional_bias_n_spatial: int = 32,
attn_dist_mode: str = "v0",
):
@@ -135,137 +135,41 @@ def forward(
return x
-# class BidirectionalRelativePositionalAttention(RelativePositionalAttention):
-# def forward(
-# self,
-# query1: torch.Tensor,
-# query2: torch.Tensor,
-# coords: torch.Tensor,
-# padding_mask: torch.Tensor = None,
-# ):
-# B, N, D = query1.size()
-# q1 = self.q_pro(query1) # (B, N, D)
-# q2 = self.q_pro(query2) # (B, N, D)
-# v1 = self.v_pro(query1) # (B, N, D)
-# v2 = self.v_pro(query2) # (B, N, D)
-
-# # (B, nh, N, hs)
-# q1 = q1.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
-# v1 = v1.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
-# q2 = q2.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
-# v2 = v2.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
-
-# attn_mask = torch.zeros(
-# (B, self.n_head, N, N), device=query1.device, dtype=q1.dtype
-# )
-
-# # add negative value but not too large to keep mixed precision loss from becoming nan
-# attn_ignore_val = -1e3
-
-# # spatial cutoff
-# yx = coords[..., 1:]
-# spatial_dist = torch.cdist(yx, yx)
-# spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1)
-# attn_mask.masked_fill_(spatial_mask, attn_ignore_val)
-
-# # dont add positional bias to self-attention if coords is None
-# if coords is not None:
-# if self._mode == "bias":
-# attn_mask = attn_mask + self.pos_bias(coords)
-# elif self._mode == "rope":
-# q1, q2 = self.rot_pos_enc(q1, q2, coords)
-# else:
-# pass
-
-# dist = torch.cdist(coords, coords, p=2)
-# attn_mask += torch.exp(-0.1 * dist.unsqueeze(1))
-
-# # if given key_padding_mask = (B,N) then ignore those tokens (e.g. padding tokens)
-# if padding_mask is not None:
-# ignore_mask = torch.logical_or(
-# padding_mask.unsqueeze(1), padding_mask.unsqueeze(2)
-# ).unsqueeze(1)
-# attn_mask.masked_fill_(ignore_mask, attn_ignore_val)
-
-# self.attn_mask = attn_mask.clone()
-
-# y1 = nn.functional.scaled_dot_product_attention(
-# q1,
-# q2,
-# v1,
-# attn_mask=attn_mask,
-# dropout_p=self.dropout if self.training else 0,
-# )
-# y2 = nn.functional.scaled_dot_product_attention(
-# q2,
-# q1,
-# v2,
-# attn_mask=attn_mask,
-# dropout_p=self.dropout if self.training else 0,
-# )
-
-# y1 = y1.transpose(1, 2).contiguous().view(B, N, D)
-# y1 = self.proj(y1)
-# y2 = y2.transpose(1, 2).contiguous().view(B, N, D)
-# y2 = self.proj(y2)
-# return y1, y2
-
-
-# class BidirectionalCrossAttention(nn.Module):
-# def __init__(
-# self,
-# coord_dim: int = 2,
-# d_model=256,
-# num_heads=4,
-# dropout=0.1,
-# window: int = 16,
-# cutoff_spatial: int = 256,
-# positional_bias: Literal["bias", "rope", "none"] = "bias",
-# positional_bias_n_spatial: int = 32,
-# ):
-# super().__init__()
-# self.positional_bias = positional_bias
-# self.attn = BidirectionalRelativePositionalAttention(
-# coord_dim,
-# d_model,
-# num_heads,
-# cutoff_spatial=cutoff_spatial,
-# n_spatial=positional_bias_n_spatial,
-# cutoff_temporal=window,
-# n_temporal=window,
-# dropout=dropout,
-# mode=positional_bias,
-# )
-
-# self.mlp = FeedForward(d_model)
-# self.norm1 = nn.LayerNorm(d_model)
-# self.norm2 = nn.LayerNorm(d_model)
-
-# def forward(
-# self,
-# x: torch.Tensor,
-# y: torch.Tensor,
-# coords: torch.Tensor,
-# padding_mask: torch.Tensor = None,
-# ):
-# x = self.norm1(x)
-# y = self.norm1(y)
-
-# # cross attention
-# # setting coords to None disables positional bias
-# x2, y2 = self.attn(
-# x,
-# y,
-# coords=coords if self.positional_bias else None,
-# padding_mask=padding_mask,
-# )
-# # print(torch.norm(x2).item()/torch.norm(x).item())
-# x = x + x2
-# x = x + self.mlp(self.norm2(x))
-# y = y + y2
-# y = y + self.mlp(self.norm2(y))
-
-# return x, y
+class LearnedRoPERotation(nn.Module):
+    def __init__(self, coord_dim: int, feature_dim: int, rope_dim: int | None = None):
+ super().__init__()
+ self.rope_dim = rope_dim or feature_dim
+ # MLP to predict rotation angle(s) from coords
+ self.angle_mlp = nn.Sequential(
+ nn.Linear(coord_dim, 64),
+ nn.ReLU(),
+ nn.Linear(64, self.rope_dim // 2) # one angle per feature pair
+ )
+
+    def forward(self, features, coords):
+        """Rotates pairs of feature channels by coordinate-dependent learned angles.
+
+        Args:
+            features: Tensor of shape (B, N, D).
+            coords: Tensor of shape (B, N, coord_dim).
+        """
+ B, N, D = features.shape
+        rope_dim = self.rope_dim
+        assert rope_dim % 2 == 0 and rope_dim <= D, "rope_dim must be even and <= the feature dim"
+
+ # Predict angles (B, N, rope_dim//2)
+        angles = self.angle_mlp(coords)  # a subset of the coords could be selected here
+        # cos/sin each have shape (B, N, rope_dim // 2), one angle per feature pair
+ cos = torch.cos(angles)
+ sin = torch.sin(angles)
+ # Prepare features for rotation
+ f = features[..., :rope_dim].reshape(B, N, -1, 2)
+ x, y = f[..., 0], f[..., 1]
+ # Apply rotation
+ x_rot = x * cos - y * sin
+ y_rot = x * sin + y * cos
+ rotated = torch.stack([x_rot, y_rot], dim=-1).reshape(B, N, rope_dim)
+ # Concatenate with the rest of the features if needed
+ if rope_dim < D:
+ rotated = torch.cat([rotated, features[..., rope_dim:]], dim=-1)
+ return rotated
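+
+# Minimal usage sketch (illustrative shapes): rotate 64-dim features of 10
+# tokens with 2D coordinates; the output keeps the input feature shape.
+#
+#   rot = LearnedRoPERotation(coord_dim=2, feature_dim=64)
+#   features, coords = torch.randn(1, 10, 64), torch.rand(1, 10, 2)
+#   rotated = rot(features, coords)  # torch.Size([1, 10, 64])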
class TrackingTransformer(torch.nn.Module):
@@ -273,6 +177,8 @@ def __init__(
self,
coord_dim: int = 3,
feat_dim: int = 0,
+ pretrained_feat_dim: int = 0,
+ reduced_pretrained_feat_dim: int = 128,
d_model: int = 128,
nhead: int = 4,
num_encoder_layers: int = 4,
@@ -288,12 +194,16 @@ def __init__(
"none", "linear", "softmax", "quiet_softmax"
] = "quiet_softmax",
attn_dist_mode: str = "v0",
+ disable_xy_coords: bool = False,
+ disable_all_coords: bool = False,
):
super().__init__()
-
+
self.config = dict(
coord_dim=coord_dim,
feat_dim=feat_dim,
+ pretrained_feat_dim=pretrained_feat_dim,
+ reduced_pretrained_feat_dim=reduced_pretrained_feat_dim,
pos_embed_per_dim=pos_embed_per_dim,
d_model=d_model,
nhead=nhead,
@@ -307,15 +217,33 @@ def __init__(
feat_embed_per_dim=feat_embed_per_dim,
causal_norm=causal_norm,
attn_dist_mode=attn_dist_mode,
+ disable_xy_coords=disable_xy_coords,
+ disable_all_coords=disable_all_coords,
)
-
- # TODO remove, alredy present in self.config
- # self.window = window
- # self.feat_dim = feat_dim
- # self.coord_dim = coord_dim
-
+
+ # TODO temp attr, add as train config arg
+ if pretrained_feat_dim > 0:
+ self.reduced_pretrained_feat_dim = reduced_pretrained_feat_dim
+ else:
+ self.reduced_pretrained_feat_dim = 0
+ self._return_norms = True
+ self.norms = {}
+
+ self._disable_xy_coords = disable_xy_coords
+ self._disable_all_coords = disable_all_coords
+
+ if self._disable_all_coords:
+ coords_proj_dims = 0
+ elif self._disable_xy_coords:
+ coords_proj_dims = pos_embed_per_dim
+ else:
+ coords_proj_dims = (1 + coord_dim) * pos_embed_per_dim
+
+ feats_proj_dims = feat_dim * feat_embed_per_dim
+
self.proj = nn.Linear(
- (1 + coord_dim) * pos_embed_per_dim + feat_dim * feat_embed_per_dim, d_model
+ coords_proj_dims + feats_proj_dims + self.reduced_pretrained_feat_dim,
+ d_model
)
self.norm = nn.LayerNorm(d_model)
@@ -363,17 +291,36 @@ def __init__(
)
else:
self.feat_embed = nn.Identity()
-
- self.pos_embed = PositionalEncoding(
- cutoffs=(window,) + (spatial_pos_cutoff,) * coord_dim,
- n_pos=(pos_embed_per_dim,) * (1 + coord_dim),
- )
+
+ if pretrained_feat_dim > 0:
+ self.ptfeat_proj = nn.Sequential(
+ nn.Linear(pretrained_feat_dim, self.reduced_pretrained_feat_dim),
+ )
+ self.ptfeat_norm = nn.LayerNorm(self.reduced_pretrained_feat_dim)
+ else:
+ self.ptfeat_proj = nn.Identity()
+ self.ptfeat_norm = nn.Identity()
+
+ if self._disable_all_coords:
+ self.pos_embed = nn.Identity()
+
+ elif self._disable_xy_coords:
+ self.pos_embed = PositionalEncoding(
+ cutoffs=(window,),
+ n_pos=(pos_embed_per_dim,),
+ )
+ else:
+ self.pos_embed = PositionalEncoding(
+ cutoffs=(window,) + (spatial_pos_cutoff,) * coord_dim,
+ n_pos=(pos_embed_per_dim,) * (1 + coord_dim),
+ )
# self.pos_embed = NoPositionalEncoding(d=pos_embed_per_dim * (1 + coord_dim))
- def forward(self, coords, features=None, padding_mask=None):
+ def forward(self, coords, features=None, pretrained_features=None, padding_mask=None):
assert coords.ndim == 3 and coords.shape[-1] in (3, 4)
_B, _N, _D = coords.shape
+ device = coords.device.type
# disable padded coords (such that it doesnt affect minimum)
if padding_mask is not None:
@@ -384,15 +331,67 @@ def forward(self, coords, features=None, padding_mask=None):
min_time = coords[:, :, :1].min(dim=1, keepdims=True).values
coords = coords - min_time
- pos = self.pos_embed(coords)
-
- if features is None or features.numel() == 0:
- features = pos
+ if self._disable_xy_coords:
+ coords_feat = coords[:, :, :1].clone()
else:
- features = self.feat_embed(features)
- features = torch.cat((pos, features), axis=-1)
+ coords_feat = coords.clone()
- features = self.proj(features)
+ if not self._disable_all_coords:
+ pos = self.pos_embed(coords_feat)
+ else:
+ pos = None
+
+ if self._return_norms:
+ self.norms = {}
+ if not self._disable_all_coords:
+ self.norms["pos_embed"] = pos.norm(dim=-1).detach().cpu().mean().item()
+ self.norms["coords"] = coords_feat.norm(dim=-1).detach().cpu().mean().item()
+
+ with torch.amp.autocast(enabled=False, device_type=device):
+ # Determine if we have any features to use
+ has_features = features is not None and features.numel() > 0
+ has_pretrained = pretrained_features is not None and pretrained_features.numel() > 0 and self.config["pretrained_feat_dim"] > 0
+
+ if self._return_norms:
+ if has_features:
+ self.norms["features"] = features.norm(dim=-1).detach().cpu().mean().item()
+ if has_pretrained:
+ self.norms["pretrained_features"] = pretrained_features.norm(dim=-1).detach().cpu().mean().item()
+
+ if not has_features and not has_pretrained:
+ if self._disable_all_coords:
+ raise ValueError("features is None and all coords are disabled. Please enable at least one of the two.")
+ features_out = pos
+ else:
+ # Start with features if present, else None
+ features_out = self.feat_embed(features) if has_features else None
+ if self._return_norms and has_features:
+ self.norms["features_out"] = features_out.norm(dim=-1).detach().cpu().mean().item()
+
+ # Add pretrained features if configured
+ if self.config["pretrained_feat_dim"] > 0 and has_pretrained:
+ pt_features = self.ptfeat_proj(pretrained_features)
+ pt_features = self.ptfeat_norm(pt_features)
+ if self._return_norms:
+ self.norms["pt_features_out"] = pt_features.norm(dim=-1).detach().cpu().mean().item()
+ if features_out is not None:
+ features_out = torch.cat((features_out, pt_features), dim=-1)
+ else:
+ features_out = pt_features
+
+ # Add encoded coords if not disabled
+ if not self._disable_all_coords:
+ if features_out is not None:
+ features_out = torch.cat((pos, features_out), axis=-1)
+ else:
+ features_out = pos
+
+ features = self.proj(features_out)
+ if self._return_norms:
+ self.norms["features_cat"] = features_out.norm(dim=-1).detach().cpu().mean().item()
+ self.norms["features_proj"] = features.norm(dim=-1).detach().cpu().mean().item()
+ # Clamp input when returning to mixed precision
+ features = features.clamp(torch.finfo(torch.float16).min, torch.finfo(torch.float16).max)
features = self.norm(features)
x = features
@@ -411,7 +410,10 @@ def forward(self, coords, features=None, padding_mask=None):
y = self.head_y(y)
# outer product is the association matrix (logits)
- A = torch.einsum("bnd,bmd->bnm", x, y)
+ A = torch.einsum("bnd,bmd->bnm", x, y) # /math.sqrt(_D)
+
+ if torch.any(torch.isnan(A)):
+ logger.error("NaN in A")
return A
diff --git a/trackastra/model/model_api.py b/trackastra/model/model_api.py
index ecdb552..90fddbb 100644
--- a/trackastra/model/model_api.py
+++ b/trackastra/model/model_api.py
@@ -3,13 +3,14 @@
from pathlib import Path
from typing import Literal
+import dask.array as da
import numpy as np
import tifffile
import torch
import yaml
from tqdm import tqdm
-from ..data import build_windows, get_features, load_tiff_timeseries
+from ..data import FeatureExtractor, build_windows, get_features, load_tiff_timeseries
from ..tracking import TrackGraph, build_graph, track_greedy
from ..utils import normalize
from .model import TrackingTransformer
@@ -21,12 +22,50 @@
class Trackastra:
+ """A transformer-based tracking model for time-lapse data.
+
+ Trackastra links segmented objects across time frames by predicting
+ associations with a transformer model trained on diverse time-lapse videos.
+
+ The model takes as input:
+ - A sequence of images of shape (T,(Z),Y,X)
+ - Corresponding instance segmentation masks of shape (T,(Z),Y,X)
+
+ It supports multiple tracking modes:
+ - greedy_nodiv: Fast greedy linking without division
+ - greedy: Fast greedy linking with division
+ - ilp: Integer Linear Programming based linking (more accurate but slower)
+
+ Examples:
+ >>> # Load example data
+ >>> from trackastra.data import example_data_bacteria
+ >>> imgs, masks = example_data_bacteria()
+ >>>
+ >>> # Load pretrained model and track
+ >>> model = Trackastra.from_pretrained("general_2d", device="cuda")
+ >>> track_graph = model.track(imgs, masks, mode="greedy")
+ """
+
def __init__(
self,
transformer: TrackingTransformer,
train_args: dict,
device: Literal["cuda", "mps", "cpu", "automatic", None] = None,
):
+ """Initialize Trackastra model.
+
+ Args:
+ transformer: The underlying transformer model.
+ train_args: Training configuration arguments.
+ device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
+ """
+ """Initialize Trackastra model.
+
+ Args:
+ transformer: The underlying transformer model.
+ train_args: Training configuration arguments.
+ device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
+ """
if device == "cuda":
if torch.cuda.is_available():
self.device = "cuda"
@@ -67,38 +106,76 @@ def __init__(
self.transformer = transformer.to(self.device)
self.train_args = train_args
+ self.imgs_path = None
+ self.masks_path = None
+ self.feature_extractor = None
@classmethod
- def from_folder(cls, dir: Path, device: str | None = None):
+ def from_folder(cls, dir: Path | str, device: str | None = None, checkpoint_path: str | None = None):
+ """Load a Trackastra model from a local folder.
+
+ Args:
+ dir: Path to model folder containing:
+ - model weights
+ - train_config.yaml with training arguments
+ device: Device to run model on.
+ checkpoint_path: Path to model checkpoint file (defaults to "model.pt" in dir).
+
+ Returns:
+ Trackastra model instance.
+ """
# Always load to cpu first
- transformer = TrackingTransformer.from_folder(dir, map_location="cpu")
+        if checkpoint_path is None:
+            checkpoint_path = "model.pt"
+        dir = Path(dir).expanduser()
+        transformer = TrackingTransformer.from_folder(
+            dir, map_location="cpu", checkpoint_path=checkpoint_path
+        )
train_args = yaml.load(open(dir / "train_config.yaml"), Loader=yaml.FullLoader)
return cls(transformer=transformer, train_args=train_args, device=device)
- # TODO make safer
@classmethod
def from_pretrained(
cls, name: str, device: str | None = None, download_dir: Path | None = None
):
+ """Load a pretrained Trackastra model.
+
+ Available pretrained models are described in detail in pretrained.json.
+
+ Args:
+ name: Name of pretrained model (e.g. "general_2d").
+ device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
+ download_dir: Directory to download model to (defaults to ~/.cache/trackastra).
+
+ Returns:
+ Trackastra model instance.
+ """
folder = download_pretrained(name, download_dir)
# download zip from github to location/name, then unzip
return cls.from_folder(folder, device=device)
def _predict(
self,
- imgs: np.ndarray,
- masks: np.ndarray,
+ imgs: np.ndarray | da.Array,
+ masks: np.ndarray | da.Array,
edge_threshold: float = 0.05,
n_workers: int = 0,
+ normalize_imgs: bool = True,
progbar_class=tqdm,
):
logger.info("Predicting weights for candidate graph")
- imgs = normalize(imgs)
+ if normalize_imgs:
+ if isinstance(imgs, da.Array):
+ imgs = imgs.map_blocks(normalize)
+ else:
+ imgs = normalize(imgs)
+
self.transformer.eval()
features = get_features(
detections=masks,
imgs=imgs,
+ features_type=self.train_args["features"],
+ feature_extractor=self.feature_extractor,
ndim=self.transformer.config["coord_dim"],
n_workers=n_workers,
progbar_class=progbar_class,
@@ -158,13 +235,68 @@ def _track_from_predictions(
def track(
self,
- imgs: np.ndarray,
- masks: np.ndarray,
+ imgs: np.ndarray | da.Array,
+ masks: np.ndarray | da.Array,
mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
+ normalize_imgs: bool = True,
progbar_class=tqdm,
+ n_workers: int = 0,
**kwargs,
) -> TrackGraph:
- predictions = self._predict(imgs, masks, progbar_class=progbar_class)
+ """Track objects across time frames.
+
+ This method links segmented objects across time frames using the specified
+ tracking mode. No hyperparameters need to be chosen beyond the tracking mode.
+
+ Args:
+ imgs: Input images of shape (T,(Z),Y,X) (numpy or dask array)
+ masks: Instance segmentation masks of shape (T,(Z),Y,X).
+ mode: Tracking mode:
+ - "greedy_nodiv": Fast greedy linking without division
+ - "greedy": Fast greedy linking with division
+ - "ilp": Integer Linear Programming based linking (more accurate but slower)
+ progbar_class: Progress bar class to use.
+ n_workers: Number of worker processes for feature extraction.
+ normalize_imgs: Whether to normalize the images.
+ **kwargs: Additional arguments passed to tracking algorithm.
+
+ Returns:
+ TrackGraph containing the tracking results.
+ """
+        if imgs.shape != masks.shape:
+ raise RuntimeError(
+ f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
+ )
+
+        if imgs.ndim != self.transformer.config["coord_dim"] + 1:
+ raise RuntimeError(
+ f"images should be a sequence of {self.transformer.config['coord_dim']}D images"
+ )
+ feat_type = self.train_args["features"]
+ if feat_type == "pretrained_feats" or feat_type == "pretrained_feats_aug":
+ additional_features = self.train_args.get(
+ "pretrained_feats_additional_props", None
+ )
+ if self.imgs_path is None:
+ save_path = "./embeddings"
+ else:
+ save_path = self.imgs_path / "embeddings"
+ self.feature_extractor = FeatureExtractor.from_model_name(
+ self.train_args["pretrained_feats_model"],
+ imgs.shape[-2:],
+ save_path=save_path,
+ mode=self.train_args["pretrained_feats_mode"],
+ device="cuda" if torch.cuda.is_available() else "cpu",
+ additional_features=additional_features,
+ )
+ self.feature_extractor.force_recompute = True
+ predictions = self._predict(
+ imgs,
+ masks,
+ normalize_imgs=normalize_imgs,
+ progbar_class=progbar_class,
+ n_workers=n_workers,
+ )
track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs)
return track_graph
@@ -173,27 +305,38 @@ def track_from_disk(
imgs_path: Path,
masks_path: Path,
mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
+ normalize_imgs: bool = True,
**kwargs,
) -> tuple[TrackGraph, np.ndarray]:
- """Track directly from two series of tiff files.
+ """Track objects directly from image and mask files on disk.
+
+ This method supports both single tiff files and directories
Args:
- imgs_path:
- Options
- - Directory containing a series of numbered tiff files. Each file contains an image of shape (C),(Z),Y,X.
- - Single tiff file with time series of shape T,(C),(Z),Y,X.
- masks_path:
- Options
- - Directory containing a series of numbered tiff files. Each file contains an image of shape (C),(Z),Y,X.
- - Single tiff file with time series of shape T,(Z),Y,X.
- mode (optional):
- Mode for candidate graph pruning.
+ imgs_path: Path to input images. Can be:
+ - Directory containing numbered tiff files of shape (C),(Z),Y,X
+ - Single tiff file with time series of shape T,(C),(Z),Y,X
+ masks_path: Path to mask files. Can be:
+ - Directory containing numbered tiff files of shape (Z),Y,X
+ - Single tiff file with time series of shape T,(Z),Y,X
+ mode: Tracking mode:
+ - "greedy_nodiv": Fast greedy linking without division
+ - "greedy": Fast greedy linking with division
+ - "ilp": Integer Linear Programming based linking (more accurate but slower)
+ normalize_imgs: Whether to normalize the images.
+ **kwargs: Additional arguments passed to tracking algorithm.
+
+ Returns:
+ Tuple of (TrackGraph, tracked masks).
"""
if not imgs_path.exists():
raise FileNotFoundError(f"{imgs_path=} does not exist.")
if not masks_path.exists():
raise FileNotFoundError(f"{masks_path=} does not exist.")
+ self.imgs_path = imgs_path
+ self.masks_path = masks_path
+
if imgs_path.is_dir():
imgs = load_tiff_timeseries(imgs_path)
else:
@@ -226,4 +369,6 @@ def track_from_disk(
f"Img shape {imgs.shape} and mask shape {masks. shape} do not match."
)
- return self.track(imgs, masks, mode, **kwargs), masks
+ return self.track(
+ imgs, masks, mode, normalize_imgs=normalize_imgs, **kwargs
+ ), masks
diff --git a/trackastra/model/model_parts.py b/trackastra/model/model_parts.py
index 5f6e042..b25c8a1 100644
--- a/trackastra/model/model_parts.py
+++ b/trackastra/model/model_parts.py
@@ -8,6 +8,17 @@
import torch.nn.functional as F
from torch import nn
+try:
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+    _FLASH_ATTN = torch.cuda.is_available()
+except ImportError:
+ flash_attn_varlen_qkvpacked_func = None
+ _FLASH_ATTN = False
+
+# if not _FLASH_ATTN:
+# warnings.warn("flash_attn not found or not available for device, falling back to normal attention.")
+# warnings.warn("Install with\n\npip install flash-attn --no-build-isolation\n\n")
+
from .rope import RotaryPositionalEncoding
logger = logging.getLogger(__name__)
@@ -58,7 +69,7 @@ def __init__(
def forward(self, coords: torch.Tensor):
_B, _N, D = coords.shape
- assert D == len(self.freqs)
+ assert D == len(self.freqs), f"coords dim {D} must be equal to number of frequencies {len(self.freqs)}"
embed = torch.cat(
tuple(
torch.cat(
@@ -169,6 +180,13 @@ def __init__(
mode: Literal["bias", "rope", "none"] = "bias",
attn_dist_mode: str = 'v0'
):
+ """
+
+ attn_dist_mode: str
+ v0: exponential decay
+ v1: exponential decay with cutoff_spatial
+ v2: no masking (except padding_mask).
+ """
super().__init__()
if not embed_dim % (2 * n_head) == 0:
@@ -220,13 +238,13 @@ def __init__(
self._mode = mode
def forward(
- self,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- coords: torch.Tensor,
- padding_mask: torch.Tensor = None,
- ):
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ coords: torch.Tensor,
+ padding_mask: torch.Tensor = None,
+ ):
B, N, D = query.size()
q = self.q_pro(query) # (B, N, D)
k = self.k_pro(key) # (B, N, D)
@@ -243,11 +261,12 @@ def forward(
# add negative value but not too large to keep mixed precision loss from becoming nan
attn_ignore_val = -1e3
- # spatial cutoff
- yx = coords[..., 1:]
- spatial_dist = torch.cdist(yx, yx)
- spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1)
- attn_mask.masked_fill_(spatial_mask, attn_ignore_val)
+ if self.attn_dist_mode != 'v2':
+ # spatial cutoff
+ yx = coords[..., 1:]
+ spatial_dist = torch.cdist(yx, yx)
+ spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1)
+ attn_mask.masked_fill_(spatial_mask, attn_ignore_val)
# dont add positional bias to self-attention if coords is None
if coords is not None:
@@ -263,6 +282,8 @@ def forward(
attn_mask += torch.exp(-0.1 * dist.unsqueeze(1))
elif self.attn_dist_mode == 'v1':
attn_mask += torch.exp(-5 * spatial_dist.unsqueeze(1) / self.cutoff_spatial)
+ elif self.attn_dist_mode == 'v2':
+ pass
else:
raise ValueError(f"Unknown attn_dist_mode {self.attn_dist_mode}")
@@ -271,15 +292,200 @@ def forward(
ignore_mask = torch.logical_or(
padding_mask.unsqueeze(1), padding_mask.unsqueeze(2)
).unsqueeze(1)
- attn_mask.masked_fill_(ignore_mask, attn_ignore_val)
+ if self.attn_dist_mode == 'v2':
+ attn_mask = ~ignore_mask
+ else:
+ attn_mask.masked_fill_(ignore_mask, attn_ignore_val)
- # self.attn_mask = attn_mask.clone()
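+        # keep the most recent attention mask as an attribute for inspection/debugging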
+ self.attn_mask = attn_mask.clone()
- y = F.scaled_dot_product_attention(
- q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0
- )
+ if _FLASH_ATTN and self.attn_dist_mode == 'v2' and False: # Disable for now
+ y = compute_attention_with_unpadding(q, k, v, padding_mask)
+ else:
+ y = F.scaled_dot_product_attention(
+ q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0
+ )
y = y.transpose(1, 2).contiguous().view(B, N, D)
# output projection
y = self.proj(y)
return y
+
+
+def compute_attention_with_unpadding(q, k, v, padding_mask):
+ """Compute self-attention using flash_attn_varlen_qkvpacked_func with unpadding.
+
+ Args:
+ q, k, v: Tensors of shape (B, H, N, D)
+ padding_mask: Tensor of shape (B, N), where True means the element should be ignored.
+
+ Returns:
+ output: Tensor of shape (B, H, N, D), the result of the attention computation.
+ """
+ B, H, N, D = q.shape
+ assert q.shape == k.shape == v.shape
+ assert padding_mask.shape == (B, N)
+
+ # Extract sequence lengths and create cumulative sequence lengths
+ valid_tokens_mask = ~padding_mask # Flip the mask so True means valid
+ lens = valid_tokens_mask.sum(dim=-1).tolist() # Length of each sequence
+ cu_seqlens = torch.tensor([0, *torch.cumsum(torch.tensor(lens), dim=0).tolist()], dtype=torch.int32, device=q.device)
+
+ # Unpad Q, K, V
+ q_unpadded = q.transpose(1, 2)[valid_tokens_mask] # Shape: (total_tokens, H, D)
+ k_unpadded = k.transpose(1, 2)[valid_tokens_mask] # Shape: (total_tokens, H, D)
+ v_unpadded = v.transpose(1, 2)[valid_tokens_mask] # Shape: (total_tokens, H, D)
+
+ # Stack Q, K, V into a single tensor for FlashAttention
+ qkv_unpadded = torch.stack([q_unpadded, k_unpadded, v_unpadded], dim=1) # Shape: (total_tokens, 3, H, D)
+
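+    # flash-attn kernels require fp16/bf16 inputs; the original dtype is restored below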
+ qkv_unpadded = qkv_unpadded.bfloat16()
+ # FlashAttention
+ max_seqlen = max(lens) # Maximum sequence length in the batch
+ output_unpadded = flash_attn_varlen_qkvpacked_func(
+ qkv_unpadded, # (total_tokens, 3, H, D)
+ cu_seqlens, # (B + 1,)
+ max_seqlen=max_seqlen,
+ dropout_p=0.0, # Set to 0.0 for evaluation
+ causal=False,
+ ) # Output: (total_tokens, H, D)
+
+ output_unpadded = output_unpadded.to(q.dtype)
+ # Re-pad to original dimensions
+ output_padded = torch.zeros((B, N, H, D), dtype=output_unpadded.dtype, device=output_unpadded.device)
+ output_padded[valid_tokens_mask] = output_unpadded
+ output_padded = output_padded.transpose(1, 2) # Shape: (B, H, N, D)
+
+ return output_padded
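+
+# Illustrative self-check (assumes a CUDA device with flash-attn installed).
+# Without padding, the unpadded flash path should match plain scaled
+# dot-product attention up to bf16 precision:
+#
+#     q = k = v = torch.randn(2, 4, 8, 16, device="cuda")
+#     pad = torch.zeros(2, 8, dtype=torch.bool, device="cuda")
+#     y1 = compute_attention_with_unpadding(q, k, v, pad)
+#     y2 = F.scaled_dot_product_attention(q, k, v)
+#     torch.testing.assert_close(y1, y2, atol=1e-2, rtol=1e-2)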
+
+# class BidirectionalRelativePositionalAttention(RelativePositionalAttention):
+# def forward(
+# self,
+# query1: torch.Tensor,
+# query2: torch.Tensor,
+# coords: torch.Tensor,
+# padding_mask: torch.Tensor = None,
+# ):
+# B, N, D = query1.size()
+# q1 = self.q_pro(query1) # (B, N, D)
+# q2 = self.q_pro(query2) # (B, N, D)
+# v1 = self.v_pro(query1) # (B, N, D)
+# v2 = self.v_pro(query2) # (B, N, D)
+
+# # (B, nh, N, hs)
+# q1 = q1.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
+# v1 = v1.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
+# q2 = q2.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
+# v2 = v2.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
+
+# attn_mask = torch.zeros(
+# (B, self.n_head, N, N), device=query1.device, dtype=q1.dtype
+# )
+
+# # add negative value but not too large to keep mixed precision loss from becoming nan
+# attn_ignore_val = -1e3
+
+# # spatial cutoff
+# yx = coords[..., 1:]
+# spatial_dist = torch.cdist(yx, yx)
+# spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1)
+# attn_mask.masked_fill_(spatial_mask, attn_ignore_val)
+
+# # dont add positional bias to self-attention if coords is None
+# if coords is not None:
+# if self._mode == "bias":
+# attn_mask = attn_mask + self.pos_bias(coords)
+# elif self._mode == "rope":
+# q1, q2 = self.rot_pos_enc(q1, q2, coords)
+# else:
+# pass
+
+# dist = torch.cdist(coords, coords, p=2)
+# attn_mask += torch.exp(-0.1 * dist.unsqueeze(1))
+
+# # if given key_padding_mask = (B,N) then ignore those tokens (e.g. padding tokens)
+# if padding_mask is not None:
+# ignore_mask = torch.logical_or(
+# padding_mask.unsqueeze(1), padding_mask.unsqueeze(2)
+# ).unsqueeze(1)
+# attn_mask.masked_fill_(ignore_mask, attn_ignore_val)
+
+# self.attn_mask = attn_mask.clone()
+
+# y1 = nn.functional.scaled_dot_product_attention(
+# q1,
+# q2,
+# v1,
+# attn_mask=attn_mask,
+# dropout_p=self.dropout if self.training else 0,
+# )
+# y2 = nn.functional.scaled_dot_product_attention(
+# q2,
+# q1,
+# v2,
+# attn_mask=attn_mask,
+# dropout_p=self.dropout if self.training else 0,
+# )
+
+# y1 = y1.transpose(1, 2).contiguous().view(B, N, D)
+# y1 = self.proj(y1)
+# y2 = y2.transpose(1, 2).contiguous().view(B, N, D)
+# y2 = self.proj(y2)
+# return y1, y2
+
+
+# class BidirectionalCrossAttention(nn.Module):
+# def __init__(
+# self,
+# coord_dim: int = 2,
+# d_model=256,
+# num_heads=4,
+# dropout=0.1,
+# window: int = 16,
+# cutoff_spatial: int = 256,
+# positional_bias: Literal["bias", "rope", "none"] = "bias",
+# positional_bias_n_spatial: int = 32,
+# ):
+# super().__init__()
+# self.positional_bias = positional_bias
+# self.attn = BidirectionalRelativePositionalAttention(
+# coord_dim,
+# d_model,
+# num_heads,
+# cutoff_spatial=cutoff_spatial,
+# n_spatial=positional_bias_n_spatial,
+# cutoff_temporal=window,
+# n_temporal=window,
+# dropout=dropout,
+# mode=positional_bias,
+# )
+
+# self.mlp = FeedForward(d_model)
+# self.norm1 = nn.LayerNorm(d_model)
+# self.norm2 = nn.LayerNorm(d_model)
+
+# def forward(
+# self,
+# x: torch.Tensor,
+# y: torch.Tensor,
+# coords: torch.Tensor,
+# padding_mask: torch.Tensor = None,
+# ):
+# x = self.norm1(x)
+# y = self.norm1(y)
+
+# # cross attention
+# # setting coords to None disables positional bias
+# x2, y2 = self.attn(
+# x,
+# y,
+# coords=coords if self.positional_bias else None,
+# padding_mask=padding_mask,
+# )
+# # print(torch.norm(x2).item()/torch.norm(x).item())
+# x = x + x2
+# x = x + self.mlp(self.norm2(x))
+# y = y + y2
+# y = y + self.mlp(self.norm2(y))
+
+# return x, y
\ No newline at end of file
diff --git a/trackastra/model/predict.py b/trackastra/model/predict.py
index afcfefb..efb4e82 100644
--- a/trackastra/model/predict.py
+++ b/trackastra/model/predict.py
@@ -17,26 +17,41 @@
def predict(batch, model):
- """Args:
- batch (_type_): _description_
- model (_type_): _description_.
-
+ """Predict association scores between objects in a batch.
+
+ Args:
+        batch: Dictionary containing:
+            - features: Object features array (may be None)
+            - pretrained_features: Precomputed features from a pretrained
+                backbone (may be None)
+            - coords: Object coordinates array
+            - timepoints: Time points array
+        model: TrackingTransformer model to use for prediction.
+
Returns:
- _type_: _description_
+ Array of association scores between objects.
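+
+    Example:
+        Illustrative call (arrays are placeholders; shapes follow the
+        descriptions above):
+
+            batch = {
+                "features": feats,             # (N, F) or None
+                "pretrained_features": None,
+                "coords": coords,              # (N, 2) or (N, 3)
+                "timepoints": t,               # (N,)
+            }
+            A = predict(batch, model)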
"""
- feats = torch.from_numpy(batch["features"])
+ if batch["features"] is not None:
+ feats = torch.from_numpy(batch["features"])
+ else:
+ feats = None
+ if batch["pretrained_features"] is not None:
+ pretrained_feats = torch.from_numpy(batch["pretrained_features"])
+ else:
+ pretrained_feats = None
coords = torch.from_numpy(batch["coords"])
timepoints = torch.from_numpy(batch["timepoints"]).long()
# Hack that assumes that all parameters of a model are on the same device
device = next(model.parameters()).device
- feats = feats.unsqueeze(0).to(device)
- timepoints = timepoints.unsqueeze(0).to(device)
coords = coords.unsqueeze(0).to(device)
+ timepoints = timepoints.unsqueeze(0).to(device)
+ if feats is not None:
+ feats = feats.unsqueeze(0).to(device)
+ if pretrained_feats is not None:
+ pretrained_feats = pretrained_feats.unsqueeze(0).to(device)
# Concat timepoints to coordinates
coords = torch.cat((timepoints.unsqueeze(2).float(), coords), dim=2)
with torch.no_grad():
- A = model(coords, features=feats)
+ A = model(coords, features=feats, pretrained_features=pretrained_feats)
A = model.normalize_output(A, timepoints, coords)
# # Spatially far entries should not influence the causal normalization
@@ -60,20 +75,39 @@ def predict_windows(
edge_threshold: float = 0.05,
spatial_dim: int = 3,
progbar_class=tqdm,
+ pred_func_override=None,
) -> dict:
- """_summary_.
-
+ """Predict associations between objects across sliding windows.
+
+ This function processes a sequence of sliding windows to predict associations
+ between objects across time frames. It handles:
+ - Object tracking across time
+ - Weight normalization across windows
+ - Edge thresholding
+ - Time-based filtering
+
Args:
- windows (_type_): _description_
- features (_type_): _description_
- model (_type_): _description_
- intra_window_weight (_type_, optional): _description_. Defaults to 0.
- delta_t (_type_, optional): _description_. Defaults to 1.
- edge_threshold (_type_, optional): _description_. Defaults to 0.05.
- spatial_dim: Dimensionality of the input masks. This might be < model.coord_dim
+ windows: List of window dictionaries containing:
+ - timepoints: Array of time points
+ - labels: Array of object labels
+ - features: Object features
+ - coords: Object coordinates
+ features: List of feature objects containing:
+ - labels: Object labels
+ - timepoints: Time points
+ - coords: Object coordinates
+ model: TrackingTransformer model to use for prediction.
+ intra_window_weight: Weight factor for objects in middle of window. Defaults to 0.
+ delta_t: Maximum time difference between objects to consider. Defaults to 1.
+ edge_threshold: Minimum association score to consider. Defaults to 0.05.
+ spatial_dim: Dimensionality of input masks. May be less than model.coord_dim.
+ progbar_class: Progress bar class to use. Defaults to tqdm.
+        pred_func_override: Function that overrides the default prediction
+            function. Useful for debugging or testing other prediction methods.
+
Returns:
- _type_: _description_
+ Dictionary containing:
+ - nodes: List of node properties (id, coords, time, label)
+ - weights: Tuple of ((node_i, node_j), weight) pairs
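+
+    Example:
+        Sketch of the returned structure (values are placeholders; node
+        entries carry the properties listed above):
+
+            {
+                "nodes": [...],
+                "weights": (((0, 17), 0.93), ...),
+            }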
"""
# first get all objects/coords
time_labels_to_id = dict()
@@ -109,10 +143,17 @@ def predict_windows(
# This assumes that the samples in the dataset are ordered by time and start at 0.
batch = windows[t]
timepoints = batch["timepoints"]
+ if isinstance(timepoints, torch.Tensor):
+ timepoints = timepoints.cpu().numpy()
labels = batch["labels"]
-
- A = predict(batch, model)
-
+ if isinstance(labels, torch.Tensor):
+ labels = labels.cpu().numpy()
+
+ if pred_func_override is None:
+ A = predict(batch, model)
+ else:
+ A = pred_func_override(batch)
+
dt = timepoints[None, :] - timepoints[:, None]
time_mask = np.logical_and(dt <= delta_t, dt > 0)
A[~time_mask] = 0
diff --git a/trackastra/model/pretrained.json b/trackastra/model/pretrained.json
index 0d7e45d..a00137d 100644
--- a/trackastra/model/pretrained.json
+++ b/trackastra/model/pretrained.json
@@ -1,9 +1,9 @@
{
"general_2d": {
- "tags": ["cells, nuclei, bacteria, epithelial"],
+        "tags": ["cells", "nuclei", "bacteria", "epithelial", "yeast", "particles"],
"dimensionality": [2],
- "description": "For tracking fluorescent nuclei, bacteria (PhC), whole cells (BF, PhC, DIC), epithelial cells with fluorescent membrane.",
- "url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.1.1/general_2d.zip",
+        "description": "For tracking fluorescent nuclei, bacteria (PhC), whole cells (BF, PhC, DIC), epithelial cells with fluorescent membrane, budding yeast cells (PhC), and fluorescent particles.",
+ "url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/general_2d.zip",
"datasets": {
"Subset of Cell Tracking Challenge 2d datasets": {
"url": "https://celltrackingchallenge.net/2d-datasets/",
@@ -17,6 +17,10 @@
"url": "https://zenodo.org/records/7260137",
"reference": "Seiffarth J, Scherr T, Wollenhaupt B, Neumann O, Scharr H, Kohlheyer D, Mikut R, Nöh K. ObiWan-Microbi: OMERO-based integrated workflow for annotating microbes in the cloud. SoftwareX. 2024 May 1;26:101638."
},
+ "Bacteria Persat": {
+ "url": "https://www.p-lab.science",
+ "reference": "Datasets kindly provided by Persat lab, EPFL."
+ },
"DeepCell": {
"url": "https://datasets.deepcell.org/data",
"reference": "Schwartz, M, Moen E, Miller G, Dougherty T, Borba E, Ding R, Graf W, Pao E, Van Valen D. Caliban: Accurate cell tracking and lineage construction in live-cell imaging experiments with deep learning. Biorxiv. 2023 Sept 13:803205."
@@ -36,14 +40,37 @@
},
"Synthetic nuclei": {
"reference": "Weigert group live cell simulator."
+ },
+ "Synthetic particles": {
+ "reference": "Weigert group particle simulator."
+ },
+ "Particle Tracking Challenge": {
+ "url": "http://bioimageanalysis.org/track/#data",
+ "reference": "Chenouard, N., Smal, I., De Chaumont, F., Maška, M., Sbalzarini, I. F., Gong, Y., ... & Meijering, E. (2014). Objective comparison of particle tracking methods. Nature methods, 11(3), 281-289."
+ },
+ "Yeast Cell-ACDC": {
+ "url": "https://zenodo.org/records/6795124",
+ "reference": "Padovani, F., Mairhörmann, B., Falter-Braun, P., Lengefeld, J., & Schmoller, K. M. (2022). Segmentation, tracking and cell cycle analysis of live-cell imaging data with Cell-ACDC. BMC biology, 20(1), 174."
+ },
+ "DeepSea": {
+ "url": "https://deepseas.org/datasets/",
+ "reference": "Zargari, A., Lodewijk, G. A., Mashhadi, N., Cook, N., Neudorf, C. W., Araghbidikashani, K., ... & Shariati, S. A. (2023). DeepSea is an efficient deep-learning model for single-cell segmentation and tracking in time-lapse microscopy. Cell Reports Methods, 3(6)."
+ },
+      "Btrack": {
+ "url": "https://rdr.ucl.ac.uk/articles/dataset/Cell_tracking_reference_dataset/16595978",
+ "reference": "Ulicna, K., Vallardi, G., Charras, G., & Lowe, A. R. (2021). Automated deep lineage tree analysis using a Bayesian single cell tracking approach. Frontiers in Computer Science, 3, 734559."
+ },
+ "E. coli in mother machine": {
+ "url": "https://zenodo.org/records/11237127",
+ "reference": "O’Connor, O. M., & Dunlop, M. J. (2024). Cell-TRACTR: A transformer-based model for end-to-end segmentation and tracking of cells. bioRxiv, 2024-07."
}
}
},
"ctc": {
- "tags": ["ctc", "isbi2024"],
+ "tags": ["ctc", "Cell Tracking Challenge", "Cell Linking Benchmark"],
"dimensionality": [2, 3],
- "description": "For tracking Cell Tracking Challenge datasets. Winner of the ISBI 2024 CTC generalizable linking challenge.",
- "url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.1/ctc.zip",
+ "description": "For tracking Cell Tracking Challenge datasets. This is the successor of the winning model of the ISBI 2024 CTC generalizable linking challenge.",
+ "url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/ctc.zip",
"datasets": {
"All Cell Tracking Challenge 2d+3d datasets with available GT and ERR_SEG": {
"url": "https://celltrackingchallenge.net/3d-datasets/",
diff --git a/trackastra/model/pretrained.py b/trackastra/model/pretrained.py
index fcb86e9..3a3c047 100644
--- a/trackastra/model/pretrained.py
+++ b/trackastra/model/pretrained.py
@@ -2,6 +2,7 @@
import shutil
import tempfile
import zipfile
+from importlib.resources import files
from pathlib import Path
import requests
@@ -10,10 +11,8 @@
logger = logging.getLogger(__name__)
_MODELS = {
- "ctc": (
- "https://github.com/weigertlab/trackastra-models/releases/download/v0.1/ctc.zip"
- ),
- "general_2d": "https://github.com/weigertlab/trackastra-models/releases/download/v0.1.1/general_2d.zip",
+ "ctc": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/ctc.zip",
+ "general_2d": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/general_2d.zip",
}
@@ -57,7 +56,7 @@ def download(url: str, fname: Path):
def download_pretrained(name: str, download_dir: Path | None = None):
# TODO make safe, introduce versioning
if download_dir is None:
- download_dir = Path("~/.trackastra/.models").expanduser()
+ download_dir = files("trackastra").joinpath(".models")
else:
download_dir = Path(download_dir)
diff --git a/trackastra/model/rope.py b/trackastra/model/rope.py
index e1f027a..5beef41 100644
--- a/trackastra/model/rope.py
+++ b/trackastra/model/rope.py
@@ -75,6 +75,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor):
co = co.unsqueeze(1).repeat_interleave(2, dim=-1)
si = si.unsqueeze(1).repeat_interleave(2, dim=-1)
+
q2 = q * co + _rotate_half(q) * si
k2 = k * co + _rotate_half(k) * si
@@ -84,12 +85,14 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor):
if __name__ == "__main__":
model = RotaryPositionalEncoding((256, 256), (32, 32))
- x = 100 * torch.rand(1, 17, 2)
- q = torch.rand(1, 4, 17, 64)
- k = torch.rand(1, 4, 17, 64)
+ x = 100 * torch.randn(1, 17, 2)
+ q = torch.randn(1, 4, 17, 64)
+ k = torch.randn(1, 4, 17, 64)
q1, k1 = model(q, k, x)
A1 = q1[:, :, 0] @ k1[:, :, 0].transpose(-1, -2)
q2, k2 = model(q, k, x + 10)
A2 = q2[:, :, 0] @ k2[:, :, 0].transpose(-1, -2)
+
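+    # RoPE encodes relative positions, so the attention logits should be
+    # invariant under a global shift of the coordinates: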
+    print("translation invariant:", torch.allclose(A1, A2))
\ No newline at end of file
diff --git a/trackastra/utils/__init__.py b/trackastra/utils/__init__.py
index b9f3239..e1fb217 100644
--- a/trackastra/utils/__init__.py
+++ b/trackastra/utils/__init__.py
@@ -1,8 +1,10 @@
# ruff: noqa: F401
from .utils import (
+ add_timepoints_to_coords,
blockwise_causal_norm,
blockwise_sum,
+ masks2properties,
normalize,
preallocate_memory,
random_label_cmap,
diff --git a/trackastra/utils/utils.py b/trackastra/utils/utils.py
index 730268c..7cfc767 100644
--- a/trackastra/utils/utils.py
+++ b/trackastra/utils/utils.py
@@ -6,9 +6,12 @@
from pathlib import Path
from timeit import default_timer
+import dask.array as da
import matplotlib
import numpy as np
import torch
+from skimage.measure import regionprops
+from tqdm import tqdm
logger = logging.getLogger(__name__)
@@ -159,6 +162,73 @@ def _blockwise_sum_with_bounds(A: torch.Tensor, bounds: torch.Tensor, dim: int =
return B
+def add_timepoints_to_coords(coords, timepoints):
+ """Adds timepoints as the first dimension to the coordinates.
+
+ Args:
+ coords (np.ndarray): Coordinates of shape (N, 3) or (N, 2).
+ timepoints (np.ndarray): Timepoints of shape (N,).
+
+ Returns:
+        np.ndarray: Coordinates with timepoints prepended as the first column, of shape (N, 3) or (N, 4).
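+
+    Example:
+        A quick sketch:
+
+            coords = np.array([[1.0, 2.0], [3.0, 4.0]])
+            timepoints = np.array([0, 1])
+            add_timepoints_to_coords(coords, timepoints)
+            # -> array([[0., 1., 2.],
+            #           [1., 3., 4.]])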
+ """
+    if coords.ndim != 2 or coords.shape[1] not in (2, 3):
+ raise ValueError("coords must be a 2D array with shape (N, 2) or (N, 3).")
+ if timepoints.ndim != 1 or timepoints.shape[0] != coords.shape[0]:
+ raise ValueError("timepoints must be a 1D array with the same length as the first dimension of coords.")
+
+ return np.column_stack((timepoints, coords))
+
+
+def masks2properties(imgs, masks, return_props_by_time=False):
+ """Turn label masks into lists of properties, sorted (ascending) by time and label id.
+
+ Args:
+ imgs (np.ndarray): T, (Z), H, W
+ masks (np.ndarray): T, (Z), H, W
+ return_props_by_time (bool): If True, return properties by time
+ (dict with keys 'coords' and 'labels' for each timepoint)
+
+ Returns:
+ labels: List of labels
+ ts: List of timepoints
+ coords: List of coordinates
+ """
+ # Get coordinates, timepoints, and labels of detections
+ labels = []
+ ts = []
+ coords = []
+ properties_by_time = dict()
+ assert len(imgs) == len(masks)
+ for _t, frame in tqdm(
+ enumerate(masks),
+        total=len(masks),
+ leave=False,
+ desc="Loading masks and properties",
+ ):
+ regions = regionprops(frame)
+ t_labels = []
+ t_ts = []
+ t_coords = []
+ for _r in regions:
+ t_labels.append(_r.label)
+ t_ts.append(_t)
+ centroid = np.array(_r.centroid).astype(np.float16)
+ t_coords.append(centroid)
+
+ properties_by_time[_t] = dict(coords=t_coords, labels=t_labels)
+ labels.extend(t_labels)
+ ts.extend(t_ts)
+ coords.extend(t_coords)
+
+ labels = np.array(labels, dtype=int)
+ ts = np.array(ts, dtype=int)
+ coords = np.array(coords, dtype=np.float16)
+ if return_props_by_time:
+ return labels, ts, coords, properties_by_time
+ return labels, ts, coords
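+
+# Illustrative usage (assumes masks is a (T, H, W) integer label stack and imgs
+# the matching intensity stack):
+#
+#     labels, ts, coords = masks2properties(imgs, masks)
+#     # detection i has mask label labels[i] at timepoint ts[i], centered at coords[i]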
+
+
def _bounds_from_timepoints(timepoints: torch.Tensor):
assert timepoints.ndim == 1
bounds = torch.cat(
@@ -263,7 +333,8 @@ def blockwise_causal_norm(
A = torch.sigmoid(A)
if mask_invalid is not None:
assert mask_invalid.shape == A.shape
- A[mask_invalid] = 0
+        # use out-of-place masked_fill instead of in-place assignment
+ A = A.masked_fill(mask_invalid, 0)
u0, u1 = A, A
ma0 = ma1 = 0
@@ -304,9 +375,22 @@ def normalize_tensor(x: torch.Tensor, dim: int | None = None, eps: float = 1e-8)
return (x - mi) / (ma - mi + eps)
-def normalize(x: np.ndarray):
- mi, ma = np.percentile(x, (1, 99.8)).astype(np.float32)
- return (x - mi) / (ma - mi + 1e-8)
+def normalize(x: np.ndarray | da.Array, subsample: int | None = 4):
+    """Percentile-normalize the image.
+
+    If subsample is not None, the percentile values are computed on a view
+    subsampled along the last two axes, which is much faster for large images.
+    """
+ x = x.astype(np.float32)
+ if subsample is not None and all(s > 64 * subsample for s in x.shape[-2:]):
+ y = x[..., ::subsample, ::subsample]
+ else:
+ y = x
+
+ mi, ma = np.percentile(y, (1, 99.8)).astype(np.float32)
+ x -= mi
+ x /= ma - mi + 1e-8
+ return x
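+
+# Example: for a (T, 1024, 1024) stack, normalize(x) estimates the 1% / 99.8%
+# percentiles on a 4x-subsampled (256x256 per frame) view and then rescales the
+# full-resolution stack in float32.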
def batched(x, batch_size, device):
@@ -460,3 +544,11 @@ def str2path(x: str) -> Path:
tps = torch.repeat_interleave(torch.arange(5), 10)
C = blockwise_causal_norm(A, tps)
+
+
+def percentile_norm(b):
+    """Percentile-normalize each frame of a stack in place and clip to [0, 1]."""
+    for i, im in enumerate(b):
+        p1, p99 = np.percentile(im, (1, 99.8))
+        # small epsilon guards against division by zero on constant frames
+        b[i] = (im - p1) / (p99 - p1 + 1e-8)
+        b[i] = np.clip(b[i], 0, 1)
+    return b