diff --git a/.cursorrules b/.cursorrules new file mode 100644 index 0000000..bbb12b8 --- /dev/null +++ b/.cursorrules @@ -0,0 +1,15 @@ +You are an expert coder in Python that likes concise, elegant, and self-explanatory code. + +Here are some general rules: + + +- Follow PEP 8 naming conventions: + - snake_case for functions and variables + - PascalCase for classes + - UPPER_CASE for constants +- Write clear comments explaining the rationale behind a complex algorithms but don't be overly verbose +- Use google style docstrings +- Use type hints to improve code readability and catch potential errors. +- Use f-strings for string interpolation +- Prefer pathlib over os.path +- Prefer tuples over lists for immutable data diff --git a/.gitignore b/.gitignore index e19e224..a9c86a8 100644 --- a/.gitignore +++ b/.gitignore @@ -177,3 +177,6 @@ cython_debug/ #.idea/ trackastra/_version.py + +CTC_DATA/* +.vscode/settings.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4902265..57732db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,4 +8,4 @@ repos: rev: v0.4.6 hooks: - id: ruff - args: [--fix, --unsafe-fixes, --preview, --verbose] + args: [--fix, --unsafe-fixes, --preview, --verbose, --exit-zero] diff --git a/README.md b/README.md index 0fdc5e9..2faf8e7 100644 --- a/README.md +++ b/README.md @@ -22,11 +22,13 @@ This repository contains the Python implementation of Trackastra. Please first set up a Python environment (with Python version 3.10 or higher), preferably via [conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) or [mamba](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html#mamba-install). +### Simple installation Trackastra can then be installed from PyPI using `pip`: ```bash pip install trackastra ``` +### With ILP support For tracking with an integer linear program (ILP, which is optional) ```bash conda create --name trackastra python=3.10 --no-default-packages @@ -34,8 +36,21 @@ conda activate trackastra conda install -c conda-forge -c gurobi -c funkelab ilpy pip install "trackastra[ilp]" ``` +
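Stepping back to the `.cursorrules` file added at the top of this diff: a minimal, hypothetical Python sketch that follows the conventions it lists (the names here are illustrative only, not part of the codebase):

```python
from pathlib import Path

MAX_RETRIES = 3  # UPPER_CASE for constants


class ScoreWriter:  # PascalCase for classes
    """Writes per-frame scores to disk.

    Args:
        out_dir: Directory that receives the output file.
    """

    def __init__(self, out_dir: Path) -> None:
        self.out_dir = out_dir

    def write_scores(self, scores: tuple[float, ...]) -> Path:  # snake_case, type hints
        """Writes the scores to a text file and returns its path."""
        # A tuple signals that the scores are immutable once computed.
        path = self.out_dir / "scores.txt"  # pathlib instead of os.path
        path.write_text("\n".join(f"{s:.3f}" for s in scores))  # f-strings
        return path
```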
+### Development installation
+ +```bash +conda create --name trackastra python=3.10 --no-default-packages +conda activate trackastra +conda install -c conda-forge -c gurobi -c funkelab ilpy +git clone https://github.com/weigertlab/trackastra.git +pip install -e "./trackastra[ilp,dev]" +``` -Notes: +
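To sanity-check the development install above, a small sketch (it assumes the `[ilp,dev]` extras installed cleanly; `motile` is only importable if the ILP extra is present):

```python
# Minimal import check for the development install.
import motile  # noqa: F401  # optional ILP linking backend

import trackastra

print(f"trackastra imported from {trackastra.__file__}")
```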
+
+### Notes/Troubleshooting
+ - For the optional ILP linking, this will install [`motile`](https://funkelab.github.io/motile/index.html) and binaries for two discrete optimizers: 1. The [Gurobi Optimizer](https://www.gurobi.com/). This is a commercial solver, which requires a valid license. Academic licenses are provided for free, see [here](https://www.gurobi.com/academia/academic-program-and-licenses/) for how to obtain one. @@ -43,6 +58,9 @@ Notes: 2. The [SCIP Optimizer](https://www.scipopt.org/), a free and open source solver. If `motile` does not find a valid Gurobi license, it will fall back to using SCIP. - On MacOS, installing packages into the conda environment before installing `ilpy` can cause problems. - 2024-06-07: On Apple M3 chips, you might have to use the nightly build of `torch` and `torchvision`, or worst case build them yourself. + +
+ ## Usage @@ -113,7 +131,54 @@ v.add_labels(masks_tracked) v.add_tracks(data=napari_tracks, graph=napari_tracks_graph) ``` -### Training a model on your own data +

+### Fiji (via TrackMate)

+
+Trackastra is one of the available trackers in [TrackMate](https://imagej.net/plugins/trackmate/). For installation and usage instructions, take a look at this [tutorial](https://imagej.net/plugins/trackmate/trackers/trackmate-trackastra).
+

+### Docker images

+
+Some of our models are available as docker images on [Docker Hub](https://hub.docker.com/r/bentaculum/trackastra-track/tags). Currently, we only provide CPU-based docker images.
+
+Track within a docker container with the following command, filling in the `<placeholders>`:
+
+```bash
+docker run -it -v <local_data_folder>:/data -v <local_results_folder>:/results bentaculum/trackastra-track:<tag> --input_test /data/<image_folder> --detection_folder "<detection_folder>"
+```
+#### Example with Cell Tracking Challenge model
+
+```bash
+wget http://data.celltrackingchallenge.net/training-datasets/Fluo-N2DH-GOWT1.zip
+unzip Fluo-N2DH-GOWT1.zip
+chmod -R 775 Fluo-N2DH-GOWT1
+docker pull bentaculum/trackastra-track:model.ctc-linking.ilp
+docker run -it -v ./:/data -v ./:/results bentaculum/trackastra-track:model.ctc-linking.ilp --input_test /data/Fluo-N2DH-GOWT1/01 --detection_folder TRA
+```
+
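For reference, the docker example above can also be reproduced through the Python API shown in the Usage section; a sketch (the TIF-loading details and the choice of the `general_2d` pretrained model are assumptions):

```python
from pathlib import Path

import numpy as np
import torch
from tifffile import imread

from trackastra.model import Trackastra
from trackastra.tracking import graph_to_ctc

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the image series and the TRA detection masks as (T, Y, X) stacks.
imgs = np.stack([imread(p) for p in sorted(Path("Fluo-N2DH-GOWT1/01").glob("*.tif"))])
masks = np.stack([imread(p) for p in sorted(Path("Fluo-N2DH-GOWT1/01_GT/TRA").glob("*.tif"))])

model = Trackastra.from_pretrained("general_2d", device=device)
track_graph = model.track(imgs, masks, mode="greedy")  # or mode="ilp" with the ILP extras

# Export the result in Cell Tracking Challenge format.
ctc_tracks, masks_tracked = graph_to_ctc(track_graph, masks, outdir="tracked")
```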
+ +

+### Command Line Interface

+After [installation](#installation), simply run in your terminal + +```bash +trackastra track --help +``` + +to build a command for tracking directly from images and corresponding instance segmentation masks saved on disk as two series of TIF files. + + +## Usage: Training a model on your own data To run an example - clone this repository and got into the scripts directory with `cd trackastra/scripts`. diff --git a/scripts/TAP_aug_bacteria copy.yaml b/scripts/TAP_aug_bacteria copy.yaml new file mode 100644 index 0000000..ebe15e9 --- /dev/null +++ b/scripts/TAP_aug_bacteria copy.yaml @@ -0,0 +1,97 @@ +# Forked from zih_bacteria.yaml +name: vanvliet_dinov2_aug +# model: /home/achard/trastra_v2/BEST_2025-06-04_13-23-38_vanvliet_sam21_aug +epochs: 750 +warmup_epochs: 5 +window: 4 +attn_dist_mode: v1 +delta_cutoff: 1 +num_encoder_layers: 4 +num_decoder_layers: 4 +causal_norm: none +d_model: 128 +dropout: 0.05 +pos_embed_per_dim: 32 +lr: 0.0001 +train_samples: 32000 +max_tokens: 2048 +batch_size: 64 +detection_folders: +- TRA +crop_size: +- 320 +- 320 +attn_positional_bias: rope +ndim: 2 +# Features config +features: pretrained_feats_aug +# features: pretrained_feats +# features: wrfeat +#### Additional parameters for pretrained_feats +augment: 3 +pretrained_n_augs: 25 +# pretrained_feats_model: facebook/sam2.1-hiera-base-plus +# pretrained_feats_model: facebookresearch/co-tracker +# pretrained_feats_model: facebook/dinov2-base +pretrained_feats_model: weigertlab/tarrow +pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2 +# pretrained_feats_model: debug/encoded_labels +reduced_pretrained_feat_dim: 128 +pretrained_feats_mode: mean_patches_exact +# pretrained_feats_mode: nearest_patch +pretrained_feats_additional_props: regionprops_small +# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding +rotate_features: true +# disable_xy_coords: true +# disable_all_coords: True +cachedir: /backup/achard/cache +outdir: /home/achard/trastra_v2 +# weight_decay: 0.01 +#### Logger +logger: wandb +wandb_project: "trackastra_v2" +# Paths config +cache: true +compress: true +distributed: false +### + +input_train: +- /home/achard/CTC_DATA/vanvliet/cib/140409-03 # +- /home/achard/CTC_DATA/vanvliet/cib/140415-08 +- /home/achard/CTC_DATA/vanvliet/recA/151027-05 +- /home/achard/CTC_DATA/vanvliet/recA/151027-06 +# - /home/achard/CTC_DATA/vanvliet/recA/151027-10 # Has an empty frame, annoying +- /home/achard/CTC_DATA/vanvliet/recA/151028-01 +- /home/achard/CTC_DATA/vanvliet/recA/151029-05 +- /home/achard/CTC_DATA/vanvliet/recA/151029-11 +- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-6 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-1 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-2 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-11 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E4-12 +- /home/achard/CTC_DATA/vanvliet/metA/150317-07 +- /home/achard/CTC_DATA/vanvliet/metA/150318-06 +- /home/achard/CTC_DATA/vanvliet/metA/150331-12 +- /home/achard/CTC_DATA/vanvliet/metA/151222-10 +- /home/achard/CTC_DATA/vanvliet/metA/151222-11 +- /home/achard/CTC_DATA/vanvliet/pheA/150324-03 +- /home/achard/CTC_DATA/vanvliet/pheA/150324-05 +- /home/achard/CTC_DATA/vanvliet/pheA/150325-04 +- /home/achard/CTC_DATA/vanvliet/pheA/160112-04 +- /home/achard/CTC_DATA/vanvliet/trpL/150303-01 +- /home/achard/CTC_DATA/vanvliet/trpL/150303-08 + +input_val: +- /home/achard/CTC_DATA/vanvliet/trpL/150428-08 +- /home/achard/CTC_DATA/vanvliet/trpL/151021-11 + +input_test: +# 
- data/CTC_DATA/vanvliet/cib/140408-01 # dense labels not precise +- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-1 # ok +- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-5 # ok +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-12 # ok +- /home/achard/CTC_DATA/vanvliet/trpL/150309-04 # ok +# - data/CTC_DATA/vanvliet/trpL/150310-05 # was empty +- /home/achard/CTC_DATA/vanvliet/trpL/150310-11 # not in delta test set. quite some global moving! +- /home/achard/CTC_DATA/vanvliet/pheA/160112-06 # ok. tricky. not in delta test set. diff --git a/scripts/cotrack3_aug_bacteria.yaml b/scripts/cotrack3_aug_bacteria.yaml new file mode 100644 index 0000000..8e8d979 --- /dev/null +++ b/scripts/cotrack3_aug_bacteria.yaml @@ -0,0 +1,96 @@ +# Forked from zih_bacteria.yaml +name: vanvliet_cotracker3_aug +# model: /home/achard/trastra_v2/BEST_2025-06-04_13-23-38_vanvliet_sam21_aug +epochs: 750 +warmup_epochs: 5 +window: 4 +attn_dist_mode: v1 +delta_cutoff: 1 +num_encoder_layers: 4 +num_decoder_layers: 4 +causal_norm: none +d_model: 128 +dropout: 0.05 +pos_embed_per_dim: 32 +lr: 0.0001 +train_samples: 320005 +max_tokens: 2048 +batch_size: 64 +detection_folders: +- TRA +crop_size: +- 320 +- 320 +attn_positional_bias: rope +ndim: 2 +# Features config +features: pretrained_feats_aug +# features: pretrained_feats +# features: wrfeat +#### Additional parameters for pretrained_feats +augment: 3 +pretrained_n_augs: 25 +# pretrained_feats_model: facebook/sam2.1-hiera-base-plus +pretrained_feats_model: facebookresearch/co-tracker +# pretrained_feats_model: weigertlab/tarrow +# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2 +# pretrained_feats_model: debug/encoded_labels +reduced_pretrained_feat_dim: 128 +pretrained_feats_mode: mean_patches_exact +# pretrained_feats_mode: nearest_patch +pretrained_feats_additional_props: regionprops_small +# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding +rotate_features: true +# disable_xy_coords: true +# disable_all_coords: True +cachedir: /home/achard/cache +outdir: /home/achard/trastra_v2 +# weight_decay: 0.01 +#### Logger +logger: wandb +wandb_project: "trackastra_v2" +# Paths config +cache: true +compress: true +distributed: false +### + +input_train: +- /home/achard/CTC_DATA/vanvliet/cib/140409-03 # +- /home/achard/CTC_DATA/vanvliet/cib/140415-08 +- /home/achard/CTC_DATA/vanvliet/recA/151027-05 +- /home/achard/CTC_DATA/vanvliet/recA/151027-06 +# - /home/achard/CTC_DATA/vanvliet/recA/151027-10 # Has an empty frame, annoying +- /home/achard/CTC_DATA/vanvliet/recA/151028-01 +- /home/achard/CTC_DATA/vanvliet/recA/151029-05 +- /home/achard/CTC_DATA/vanvliet/recA/151029-11 +- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-6 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-1 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-2 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-11 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E4-12 +- /home/achard/CTC_DATA/vanvliet/metA/150317-07 +- /home/achard/CTC_DATA/vanvliet/metA/150318-06 +- /home/achard/CTC_DATA/vanvliet/metA/150331-12 +- /home/achard/CTC_DATA/vanvliet/metA/151222-10 +- /home/achard/CTC_DATA/vanvliet/metA/151222-11 +- /home/achard/CTC_DATA/vanvliet/pheA/150324-03 +- /home/achard/CTC_DATA/vanvliet/pheA/150324-05 +- /home/achard/CTC_DATA/vanvliet/pheA/150325-04 +- /home/achard/CTC_DATA/vanvliet/pheA/160112-04 +- /home/achard/CTC_DATA/vanvliet/trpL/150303-01 +- /home/achard/CTC_DATA/vanvliet/trpL/150303-08 + +input_val: +- 
/home/achard/CTC_DATA/vanvliet/trpL/150428-08 +- /home/achard/CTC_DATA/vanvliet/trpL/151021-11 + +input_test: +# - data/CTC_DATA/vanvliet/cib/140408-01 # dense labels not precise +- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-1 # ok +- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-5 # ok +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-12 # ok +- /home/achard/CTC_DATA/vanvliet/trpL/150309-04 # ok +# - data/CTC_DATA/vanvliet/trpL/150310-05 # was empty +- /home/achard/CTC_DATA/vanvliet/trpL/150310-11 # not in delta test set. quite some global moving! +- /home/achard/CTC_DATA/vanvliet/pheA/160112-06 # ok. tricky. not in delta test set. diff --git a/scripts/deepcell/SAM_deepcell.yaml b/scripts/deepcell/SAM_deepcell.yaml new file mode 100644 index 0000000..60d2c66 --- /dev/null +++ b/scripts/deepcell/SAM_deepcell.yaml @@ -0,0 +1,100 @@ +name: deepcell_cotrack_aug_no_drpt +# model: /home/achard/tratra_runs_eval/2025-05-28_10-04-50_vanvliet_sam21_aug +epochs: 1000 +warmup_epochs: 5 +window: 4 +attn_dist_mode: v1 +delta_cutoff: 1 +num_encoder_layers: 4 +num_decoder_layers: 4 +causal_norm: none +d_model: 128 +# dropout: 0.05 +pos_embed_per_dim: 32 +lr: 0.0001 +train_samples: 32000 +max_tokens: 2048 +batch_size: 128 +detection_folders: +- TRA +crop_size: +- 320 +- 320 +attn_positional_bias: rope +ndim: 2 +# Features config +features: pretrained_feats_aug +# features: pretrained_feats +# features: wrfeat +#### Additional parameters for pretrained_feats +augment: 3 +pretrained_n_augs: 25 +pretrained_feats_model: facebook/sam-vit-base +reduced_pretrained_feat_dim: 128 +# pretrained_feats_model: weigertlab/tarrow +# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2 +# pretrained_feats_model: debug/encoded_labels +pretrained_feats_mode: mean_patches_exact +# pretrained_feats_mode: nearest_patch +pretrained_feats_additional_props: regionprops_small +# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding +# rotate_features: true +# disable_xy_coords: true +# disable_all_coords: True +cachedir: /backup/achard/cache +outdir: /home/achard/trastra_runs_v2 +#### Logger +logger: wandb +wandb_project: "trackstra_v2_deepcell" +# Paths config +cache: True +compress: True +distributed: True + +input_train: +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/00 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/01 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/02 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/03 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/04 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/05 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/06 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/07 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/08 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/09 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/10 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/11 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/12 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/13 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/14 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/15 +# - /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/16 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/16_corrected +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/17 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/18 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/19 +- 
/backup/achard/CTC_DATA/deepcell/deepcell_orig/train/20 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/23 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/25 + +input_val: +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/21 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/22 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/24 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/44 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/45 + + +input_test: +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/00 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/01 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/02 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/03 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/04 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/05 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/06 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/07 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/08 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/09 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/10 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/11 diff --git a/scripts/deepcell/cotracker3_deepcell.yaml b/scripts/deepcell/cotracker3_deepcell.yaml new file mode 100644 index 0000000..d8dfd53 --- /dev/null +++ b/scripts/deepcell/cotracker3_deepcell.yaml @@ -0,0 +1,100 @@ +name: deepcell_cotrack_aug_no_drpt +# model: /home/achard/tratra_runs_eval/2025-05-28_10-04-50_vanvliet_sam21_aug +epochs: 1000 +warmup_epochs: 5 +window: 4 +attn_dist_mode: v1 +delta_cutoff: 1 +num_encoder_layers: 4 +num_decoder_layers: 4 +causal_norm: none +d_model: 128 +dropout: 0.05 +pos_embed_per_dim: 32 +lr: 0.0001 +train_samples: 32000 +max_tokens: 2048 +batch_size: 128 +detection_folders: +- TRA +crop_size: +- 320 +- 320 +attn_positional_bias: rope +ndim: 2 +# Features config +features: pretrained_feats_aug +# features: pretrained_feats +# features: wrfeat +#### Additional parameters for pretrained_feats +augment: 3 +pretrained_n_augs: 25 +pretrained_feats_model: facebookresearch/co-tracker +reduced_pretrained_feat_dim: 128 +# pretrained_feats_model: weigertlab/tarrow +# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2 +# pretrained_feats_model: debug/encoded_labels +pretrained_feats_mode: mean_patches_exact +# pretrained_feats_mode: nearest_patch +pretrained_feats_additional_props: regionprops_small +# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding +# rotate_features: true +# disable_xy_coords: true +# disable_all_coords: True +cachedir: /backup/achard/cache +outdir: /home/achard/trastra_runs_v2 +#### Logger +logger: wandb +wandb_project: "trackstra_v2_deepcell" +# Paths config +cache: True +compress: True +distributed: True + +input_train: +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/00 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/01 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/02 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/03 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/04 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/05 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/06 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/07 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/08 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/09 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/10 +- 
/backup/achard/CTC_DATA/deepcell/deepcell_orig/train/11 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/12 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/13 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/14 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/15 +# - /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/16 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/16_corrected +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/17 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/18 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/19 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/20 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/23 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/25 + +input_val: +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/21 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/22 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/24 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/44 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/45 + + +input_test: +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/00 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/01 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/02 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/03 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/04 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/05 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/06 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/07 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/08 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/09 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/10 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/test/11 diff --git a/scripts/deepcell/sam21_deepcell.yaml b/scripts/deepcell/sam21_deepcell.yaml new file mode 100644 index 0000000..c855032 --- /dev/null +++ b/scripts/deepcell/sam21_deepcell.yaml @@ -0,0 +1,185 @@ +name: deepcell_sam21_aug +epochs: 1000 +warmup_epochs: 5 +window: 4 +attn_dist_mode: v1 +num_encoder_layers: 4 +num_decoder_layers: 4 +causal_norm: none +d_model: 256 +dropout: 0.05 +pos_embed_per_dim: 32 +lr: 0.0001 +train_samples: 32000 +max_tokens: 2048 +batch_size: 128 +detection_folders: +- TRA +crop_size: +- 256 +- 256 +attn_positional_bias: rope +ndim: 2 +# Features config +features: pretrained_feats_aug +#### Additional parameters for pretrained_feats +augment: 3 +pretrained_n_augs: 25 +pretrained_feats_model: facebook/sam2.1-hiera-base-plus +reduced_pretrained_feat_dim: 128 +pretrained_feats_mode: mean_patches_exact +pretrained_feats_additional_props: regionprops_small +rotate_features: true +cachedir: /backup/achard/cache +outdir: /home/achard/trastra_v3_deepcell +# weight_decay: 0.01 +#### Logger +logger: wandb +wandb_project: "trackastra_v3_deepcell" +# Paths config + +### DEEPCELL SPECIFIC +delta_cutoff: 2 +spatial_pos_cutoff: 96 +div_upweight: 2 +### +cache: true +compress: true +distributed: false +num_workers: 24 + +input_train: +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/00 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/01 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/02 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/03 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/04 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/05 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/06 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/07 +- 
/backup/achard/CTC_DATA/deepcell/deepcell_full/train/08 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/09 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/10 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/11 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/12 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/13 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/14 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/15 +# - /backup/achard/CTC_DATA/deepcell/deepcell_full/train/16 +- /backup/achard/CTC_DATA/deepcell/deepcell_orig/train/16_corrected +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/17 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/18 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/19 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/20 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/21 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/22 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/23 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/24 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/25 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/26 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/27 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/28 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/29 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/30 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/31 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/32 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/33 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/34 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/35 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/36 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/37 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/38 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/39 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/40 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/41 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/42 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/43 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/44 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/45 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/46 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/47 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/48 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/49 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/50 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/51 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/52 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/53 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/54 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/55 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/56 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/57 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/58 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/59 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/60 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/61 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/62 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/63 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/64 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/65 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/66 +- 
/backup/achard/CTC_DATA/deepcell/deepcell_full/train/67 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/68 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/69 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/70 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/71 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/72 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/73 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/74 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/75 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/76 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/77 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/78 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/79 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/80 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/81 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/82 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/83 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/84 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/85 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/86 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/87 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/88 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/89 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/train/90 +input_val: +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/00 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/01 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/02 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/03 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/04 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/05 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/06 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/07 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/08 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/09 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/10 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/11 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/12 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/13 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/14 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/15 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/16 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/17 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/18 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/19 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/20 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/21 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/22 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/23 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/24 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/25 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/val/26 + +input_test: +- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/00 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/01 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/02 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/03 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/04 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/05 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/06 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/07 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/08 +- 
/backup/achard/CTC_DATA/deepcell/deepcell_full/test/09 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/10 +- /backup/achard/CTC_DATA/deepcell/deepcell_full/test/11 \ No newline at end of file diff --git a/scripts/dinov2_aug_bacteria.yaml b/scripts/dinov2_aug_bacteria.yaml new file mode 100644 index 0000000..150434d --- /dev/null +++ b/scripts/dinov2_aug_bacteria.yaml @@ -0,0 +1,97 @@ +# Forked from zih_bacteria.yaml +name: vanvliet_dinov2_aug +# model: /home/achard/trastra_v2/BEST_2025-06-04_13-23-38_vanvliet_sam21_aug +epochs: 750 +warmup_epochs: 5 +window: 4 +attn_dist_mode: v1 +delta_cutoff: 1 +num_encoder_layers: 4 +num_decoder_layers: 4 +causal_norm: none +d_model: 128 +dropout: 0.05 +pos_embed_per_dim: 32 +lr: 0.0001 +train_samples: 32000 +max_tokens: 2048 +batch_size: 64 +detection_folders: +- TRA +crop_size: +- 320 +- 320 +attn_positional_bias: rope +ndim: 2 +# Features config +features: pretrained_feats_aug +# features: pretrained_feats +# features: wrfeat +#### Additional parameters for pretrained_feats +augment: 3 +pretrained_n_augs: 25 +# pretrained_feats_model: facebook/sam2.1-hiera-base-plus +# pretrained_feats_model: facebookresearch/co-tracker +pretrained_feats_model: facebook/dinov2-base +# pretrained_feats_model: weigertlab/tarrow +# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2 +# pretrained_feats_model: debug/encoded_labels +reduced_pretrained_feat_dim: 128 +pretrained_feats_mode: mean_patches_exact +# pretrained_feats_mode: nearest_patch +pretrained_feats_additional_props: regionprops_small +# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding +rotate_features: true +# disable_xy_coords: true +# disable_all_coords: True +cachedir: /backup/achard/cache +outdir: /home/achard/trastra_v2 +# weight_decay: 0.01 +#### Logger +logger: wandb +wandb_project: "trackastra_v2" +# Paths config +cache: true +compress: true +distributed: false +### + +input_train: +- /home/achard/CTC_DATA/vanvliet/cib/140409-03 # +- /home/achard/CTC_DATA/vanvliet/cib/140415-08 +- /home/achard/CTC_DATA/vanvliet/recA/151027-05 +- /home/achard/CTC_DATA/vanvliet/recA/151027-06 +# - /home/achard/CTC_DATA/vanvliet/recA/151027-10 # Has an empty frame, annoying +- /home/achard/CTC_DATA/vanvliet/recA/151028-01 +- /home/achard/CTC_DATA/vanvliet/recA/151029-05 +- /home/achard/CTC_DATA/vanvliet/recA/151029-11 +- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-6 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-1 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-2 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-11 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E4-12 +- /home/achard/CTC_DATA/vanvliet/metA/150317-07 +- /home/achard/CTC_DATA/vanvliet/metA/150318-06 +- /home/achard/CTC_DATA/vanvliet/metA/150331-12 +- /home/achard/CTC_DATA/vanvliet/metA/151222-10 +- /home/achard/CTC_DATA/vanvliet/metA/151222-11 +- /home/achard/CTC_DATA/vanvliet/pheA/150324-03 +- /home/achard/CTC_DATA/vanvliet/pheA/150324-05 +- /home/achard/CTC_DATA/vanvliet/pheA/150325-04 +- /home/achard/CTC_DATA/vanvliet/pheA/160112-04 +- /home/achard/CTC_DATA/vanvliet/trpL/150303-01 +- /home/achard/CTC_DATA/vanvliet/trpL/150303-08 + +input_val: +- /home/achard/CTC_DATA/vanvliet/trpL/150428-08 +- /home/achard/CTC_DATA/vanvliet/trpL/151021-11 + +input_test: +# - data/CTC_DATA/vanvliet/cib/140408-01 # dense labels not precise +- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-1 # ok +- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-5 # ok +- 
/home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-12 # ok +- /home/achard/CTC_DATA/vanvliet/trpL/150309-04 # ok +# - data/CTC_DATA/vanvliet/trpL/150310-05 # was empty +- /home/achard/CTC_DATA/vanvliet/trpL/150310-11 # not in delta test set. quite some global moving! +- /home/achard/CTC_DATA/vanvliet/pheA/160112-06 # ok. tricky. not in delta test set. diff --git a/scripts/example_config.yaml b/scripts/example_config.yaml index 966480f..547491b 100644 --- a/scripts/example_config.yaml +++ b/scripts/example_config.yaml @@ -1,20 +1,97 @@ -batch_size: 8 -crop_size: -- 256 -- 256 +# Forked from zih_bacteria.yaml +name: vanvliet_sam21_quiet_softmax_aug +# model: /home/achard/trastra_v2/BEST_2025-06-04_13-23-38_vanvliet_sam21_aug +epochs: 1000 +warmup_epochs: 5 +window: 4 +attn_dist_mode: v1 +delta_cutoff: 1 +num_encoder_layers: 4 +num_decoder_layers: 4 +# causal_norm: none +causal_norm: quiet_softmax +d_model: 128 +dropout: 0.0 +pos_embed_per_dim: 32 +lr: 0.0001 +train_samples: 32000 +max_tokens: 2048 +batch_size: 60 detection_folders: - TRA -dropout: 0.01 -example_images: False # Slow +crop_size: +- 320 +- 320 +attn_positional_bias: rope +ndim: 2 +# Features config +features: pretrained_feats_aug +# features: pretrained_feats +# features: wrfeat +#### Additional parameters for pretrained_feats +augment: 3 +pretrained_n_augs: 25 +pretrained_feats_model: facebook/sam2.1-hiera-base-plus +# pretrained_feats_model: facebookresearch/co-tracker +# pretrained_feats_model: weigertlab/tarrow +# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2 +# pretrained_feats_model: debug/encoded_labels +reduced_pretrained_feat_dim: 128 +pretrained_feats_mode: mean_patches_exact +# pretrained_feats_mode: nearest_patch +pretrained_feats_additional_props: regionprops_small +# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding +rotate_features: true +# disable_xy_coords: true +# disable_all_coords: True +cachedir: /home/achard/cache +outdir: /home/achard/trastra_v2 +# weight_decay: 0.01 +#### Logger +logger: wandb +wandb_project: "trackastra_v2" +# Paths config +cache: true +compress: true +distributed: false +### + input_train: -- data/ctc/Fluo-N2DL-HeLa/02 +- /backup/achard/CTC_DATA/vanvliet/cib/140409-03 # +- /backup/achard/CTC_DATA/vanvliet/cib/140415-08 +- /backup/achard/CTC_DATA/vanvliet/recA/151027-05 +- /backup/achard/CTC_DATA/vanvliet/recA/151027-06 +# - /backup/achard/CTC_DATA/vanvliet/recA/151027-10 # Has an empty frame, annoying +- /backup/achard/CTC_DATA/vanvliet/recA/151028-01 +- /backup/achard/CTC_DATA/vanvliet/recA/151029-05 +- /backup/achard/CTC_DATA/vanvliet/recA/151029-11 +- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-6 +- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E2-1 +- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E2-2 +- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E3-11 +- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E4-12 +- /backup/achard/CTC_DATA/vanvliet/metA/150317-07 +- /backup/achard/CTC_DATA/vanvliet/metA/150318-06 +- /backup/achard/CTC_DATA/vanvliet/metA/150331-12 +- /backup/achard/CTC_DATA/vanvliet/metA/151222-10 +- /backup/achard/CTC_DATA/vanvliet/metA/151222-11 +- /backup/achard/CTC_DATA/vanvliet/pheA/150324-03 +- /backup/achard/CTC_DATA/vanvliet/pheA/150324-05 +- /backup/achard/CTC_DATA/vanvliet/pheA/150325-04 +- /backup/achard/CTC_DATA/vanvliet/pheA/160112-04 +- /backup/achard/CTC_DATA/vanvliet/trpL/150303-01 +- /backup/achard/CTC_DATA/vanvliet/trpL/150303-08 + input_val: -- data/ctc/Fluo-N2DL-HeLa/01 
-max_tokens: 2048 -name: example -ndim: 2 -num_decoder_layers: 5 -num_encoder_layers: 5 -outdir: runs -distributed: False -window: 6 +- /backup/achard/CTC_DATA/vanvliet/trpL/150428-08 +- /backup/achard/CTC_DATA/vanvliet/trpL/151021-11 + +input_test: +# - data/CTC_DATA/vanvliet/cib/140408-01 # dense labels not precise +- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-1 # ok +- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-5 # ok +- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E3-12 # ok +- /backup/achard/CTC_DATA/vanvliet/trpL/150309-04 # ok +# - data/CTC_DATA/vanvliet/trpL/150310-05 # was empty +- /backup/achard/CTC_DATA/vanvliet/trpL/150310-11 # not in delta test set. quite some global moving! +- /backup/achard/CTC_DATA/vanvliet/pheA/160112-06 # ok. tricky. not in delta test set. diff --git a/scripts/sam21_aug_bacteria.yaml b/scripts/sam21_aug_bacteria.yaml new file mode 100644 index 0000000..9128751 --- /dev/null +++ b/scripts/sam21_aug_bacteria.yaml @@ -0,0 +1,96 @@ +# Forked from zih_bacteria.yaml +name: vanvliet_sam21_aug +# model: /home/achard/trastra_v2/BEST_2025-06-04_13-23-38_vanvliet_sam21_aug +epochs: 1000 +warmup_epochs: 5 +window: 4 +attn_dist_mode: v1 +delta_cutoff: 1 +num_encoder_layers: 4 +num_decoder_layers: 4 +causal_norm: none +d_model: 256 +dropout: 0.05 +pos_embed_per_dim: 32 +lr: 0.0001 +train_samples: 32000 +max_tokens: 2048 +batch_size: 55 +detection_folders: +- TRA +crop_size: +- 320 +- 320 +attn_positional_bias: rope +ndim: 2 +# Features config +features: pretrained_feats_aug +# features: pretrained_feats +# features: wrfeat +#### Additional parameters for pretrained_feats +augment: 3 +pretrained_n_augs: 25 +pretrained_feats_model: facebook/sam2.1-hiera-base-plus +# pretrained_feats_model: facebookresearch/co-tracker +# pretrained_feats_model: weigertlab/tarrow +# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2 +# pretrained_feats_model: debug/encoded_labels +reduced_pretrained_feat_dim: 128 +pretrained_feats_mode: mean_patches_exact +# pretrained_feats_mode: nearest_patch +pretrained_feats_additional_props: regionprops_small +# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding +rotate_features: true +# disable_xy_coords: true +# disable_all_coords: True +cachedir: /backup/achard/cache +outdir: /home/achard/trastra_v2 +# weight_decay: 0.01 +#### Logger +logger: wandb +wandb_project: "trackastra_v2" +# Paths config +cache: true +compress: true +distributed: false +### + +input_train: +- /backup/achard/CTC_DATA/vanvliet/cib/140409-03 # +- /backup/achard/CTC_DATA/vanvliet/cib/140415-08 +- /backup/achard/CTC_DATA/vanvliet/recA/151027-05 +- /backup/achard/CTC_DATA/vanvliet/recA/151027-06 +# - /backup/achard/CTC_DATA/vanvliet/recA/151027-10 # Has an empty frame, annoying +- /backup/achard/CTC_DATA/vanvliet/recA/151028-01 +- /backup/achard/CTC_DATA/vanvliet/recA/151029-05 +- /backup/achard/CTC_DATA/vanvliet/recA/151029-11 +- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-6 +- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E2-1 +- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E2-2 +- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E3-11 +- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E4-12 +- /backup/achard/CTC_DATA/vanvliet/metA/150317-07 +- /backup/achard/CTC_DATA/vanvliet/metA/150318-06 +- /backup/achard/CTC_DATA/vanvliet/metA/150331-12 +- /backup/achard/CTC_DATA/vanvliet/metA/151222-10 +- /backup/achard/CTC_DATA/vanvliet/metA/151222-11 +- /backup/achard/CTC_DATA/vanvliet/pheA/150324-03 +- 
/backup/achard/CTC_DATA/vanvliet/pheA/150324-05 +- /backup/achard/CTC_DATA/vanvliet/pheA/150325-04 +- /backup/achard/CTC_DATA/vanvliet/pheA/160112-04 +- /backup/achard/CTC_DATA/vanvliet/trpL/150303-01 +- /backup/achard/CTC_DATA/vanvliet/trpL/150303-08 + +input_val: +- /backup/achard/CTC_DATA/vanvliet/trpL/150428-08 +- /backup/achard/CTC_DATA/vanvliet/trpL/151021-11 + +input_test: +# - data/CTC_DATA/vanvliet/cib/140408-01 # dense labels not precise +- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-1 # ok +- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-5 # ok +- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E3-12 # ok +- /backup/achard/CTC_DATA/vanvliet/trpL/150309-04 # ok +# - data/CTC_DATA/vanvliet/trpL/150310-05 # was empty +- /backup/achard/CTC_DATA/vanvliet/trpL/150310-11 # not in delta test set. quite some global moving! +- /backup/achard/CTC_DATA/vanvliet/pheA/160112-06 # ok. tricky. not in delta test set. diff --git a/scripts/sam_aug_bacteria.yaml b/scripts/sam_aug_bacteria.yaml new file mode 100644 index 0000000..17d68b7 --- /dev/null +++ b/scripts/sam_aug_bacteria.yaml @@ -0,0 +1,96 @@ +# Forked from zih_bacteria.yaml +name: vanvliet_SAM_aug +# model: /home/achard/trastra_v2/BEST_2025-06-04_13-23-38_vanvliet_sam21_aug +epochs: 750 +warmup_epochs: 5 +window: 4 +attn_dist_mode: v1 +delta_cutoff: 1 +num_encoder_layers: 4 +num_decoder_layers: 4 +causal_norm: none +d_model: 128 +dropout: 0.05 +pos_embed_per_dim: 32 +lr: 0.0001 +train_samples: 32000 +max_tokens: 2048 +batch_size: 64 +detection_folders: +- TRA +crop_size: +- 320 +- 320 +attn_positional_bias: rope +ndim: 2 +# Features config +features: pretrained_feats_aug +# features: pretrained_feats +# features: wrfeat +#### Additional parameters for pretrained_feats +augment: 3 +pretrained_n_augs: 25 +pretrained_feats_model: facebook/sam-vit-base +# pretrained_feats_model: facebookresearch/co-tracker +# pretrained_feats_model: weigertlab/tarrow +# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2 +# pretrained_feats_model: debug/encoded_labels +reduced_pretrained_feat_dim: 128 +pretrained_feats_mode: mean_patches_exact +# pretrained_feats_mode: nearest_patch +pretrained_feats_additional_props: regionprops_small +# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding +rotate_features: true +# disable_xy_coords: true +# disable_all_coords: True +cachedir: /home/achard/cache +outdir: /home/achard/trastra_v2 +# weight_decay: 0.01 +#### Logger +logger: wandb +wandb_project: "trackastra_v2" +# Paths config +cache: true +compress: true +distributed: false +### + +input_train: +- /home/achard/CTC_DATA/vanvliet/cib/140409-03 # +- /home/achard/CTC_DATA/vanvliet/cib/140415-08 +- /home/achard/CTC_DATA/vanvliet/recA/151027-05 +- /home/achard/CTC_DATA/vanvliet/recA/151027-06 +# - /home/achard/CTC_DATA/vanvliet/recA/151027-10 # Has an empty frame, annoying +- /home/achard/CTC_DATA/vanvliet/recA/151028-01 +- /home/achard/CTC_DATA/vanvliet/recA/151029-05 +- /home/achard/CTC_DATA/vanvliet/recA/151029-11 +- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-6 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-1 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E2-2 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-11 +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E4-12 +- /home/achard/CTC_DATA/vanvliet/metA/150317-07 +- /home/achard/CTC_DATA/vanvliet/metA/150318-06 +- /home/achard/CTC_DATA/vanvliet/metA/150331-12 +- /home/achard/CTC_DATA/vanvliet/metA/151222-10 +- 
/home/achard/CTC_DATA/vanvliet/metA/151222-11 +- /home/achard/CTC_DATA/vanvliet/pheA/150324-03 +- /home/achard/CTC_DATA/vanvliet/pheA/150324-05 +- /home/achard/CTC_DATA/vanvliet/pheA/150325-04 +- /home/achard/CTC_DATA/vanvliet/pheA/160112-04 +- /home/achard/CTC_DATA/vanvliet/trpL/150303-01 +- /home/achard/CTC_DATA/vanvliet/trpL/150303-08 + +input_val: +- /home/achard/CTC_DATA/vanvliet/trpL/150428-08 +- /home/achard/CTC_DATA/vanvliet/trpL/151021-11 + +input_test: +# - data/CTC_DATA/vanvliet/cib/140408-01 # dense labels not precise +- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-1 # ok +- /home/achard/CTC_DATA/vanvliet/rpsM/151029_E1-5 # ok +- /home/achard/CTC_DATA/vanvliet/rpsM/151101_E3-12 # ok +- /home/achard/CTC_DATA/vanvliet/trpL/150309-04 # ok +# - data/CTC_DATA/vanvliet/trpL/150310-05 # was empty +- /home/achard/CTC_DATA/vanvliet/trpL/150310-11 # not in delta test set. quite some global moving! +- /home/achard/CTC_DATA/vanvliet/pheA/160112-06 # ok. tricky. not in delta test set. diff --git a/scripts/train.py b/scripts/train.py index 2b74698..a9f2de6 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -16,6 +16,7 @@ import configargparse import git +import humanize import lightning as pl import numpy as np import psutil @@ -25,7 +26,6 @@ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger from lightning.pytorch.profilers import PyTorchProfiler from lightning.pytorch.utilities.rank_zero import rank_zero_only -from numerize import numerize from skimage.morphology import binary_dilation, disk from torch.optim.lr_scheduler import LRScheduler from torchvision.utils import make_grid @@ -36,6 +36,12 @@ CTCData, collate_sequence_padding, ) +from trackastra.data.pretrained_features import ( + AVAILABLE_PRETRAINED_BACKBONES, + PretrainedFeatsExtractionMode, + PretrainedFeatureExtractorConfig, +) +from trackastra.data.wrfeat import _PROPERTIES, DEFAULT_PROPERTIES, WRFeatures from trackastra.model import TrackingTransformer from trackastra.utils import ( blockwise_causal_norm, @@ -48,12 +54,13 @@ str2bool, ) -logging.basicConfig(level=logging.INFO) +# logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) warnings.simplefilter(action="ignore", category=FutureWarning) -device = torch.device("cuda" if torch.cuda.is_available() else "mps") +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") np.seterr(all="ignore") @@ -83,7 +90,6 @@ def __init__( max_epochs, cosine_final: float = 0.001, last_epoch=-1, - verbose=False, ): """Use cosine_final to switch on/off the cosine annealing. 
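Context for the scheduler hunks here: `verbose` is dropped from the `LRScheduler` constructor call, matching its deprecation in recent PyTorch. For readers unfamiliar with the schedule itself, a standalone sketch of a typical warmup-plus-cosine factor with a `cosine_final` floor (setting `cosine_final=1.0` disables the annealing, as the docstring suggests); the repo's actual `get_lr` may differ:

```python
import math


def warmup_cosine_factor(
    epoch: int, warmup_epochs: int, max_epochs: int, cosine_final: float = 0.001
) -> float:
    """Multiplicative LR factor: linear warmup, then cosine decay to cosine_final."""
    if epoch < warmup_epochs:
        # Linear ramp towards the base learning rate during warmup.
        return (epoch + 1) / warmup_epochs
    # Cosine annealing from 1.0 down to cosine_final over the remaining epochs.
    progress = (epoch - warmup_epochs) / max(1, max_epochs - warmup_epochs)
    return cosine_final + (1.0 - cosine_final) * 0.5 * (1.0 + math.cos(math.pi * progress))
```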
@@ -93,7 +99,7 @@ def __init__( self.warmup_epochs = warmup_epochs self.max_epochs = max_epochs self.cosine_final = cosine_final - super().__init__(optimizer, last_epoch, verbose) + super().__init__(optimizer, last_epoch) def get_lr(self): if not self._get_lr_called_within_step: @@ -182,6 +188,8 @@ def __init__( tracking_frequency: int = -1, # log TRA metrics every that epochs batch_val_tb_idx: int = 0, # the batch index to visualize in tensorboard div_upweight: float = 20, + # per_param_clipping: bool = False, + weight_decay: float = 0.01, ): super().__init__() @@ -196,30 +204,60 @@ def __init__( self.batch_val_tb = None self.lr = learning_rate + self.weight_decay = weight_decay self.tracking_frequency = tracking_frequency self.warmup_epochs = warmup_epochs self.max_epochs = max_epochs self.div_upweight = div_upweight - + # self.per_param_clipping = per_param_clipping + def _common_step(self, batch, eps=torch.finfo(torch.float32).eps): + # torch.autograd.set_detect_anomaly(True) feats = batch["features"] + try: + pretrained_feats = batch["pretrained_features"] + except KeyError: + pretrained_feats = None coords = batch["coords"] A = batch["assoc_matrix"] timepoints = batch["timepoints"] padding_mask = batch["padding_mask"] padding_mask = padding_mask.bool() + + if feats is not None: + if torch.any(torch.isnan(feats)): + nan_dims = torch.any(torch.isnan(feats), dim=-1) + raise ValueError("NaN in features in dimensions: ", nan_dims) + if pretrained_feats is not None: + if torch.any(torch.isnan(pretrained_feats)): + nan_dims = torch.any(torch.isnan(pretrained_feats), dim=-1) + raise ValueError("NaN in pretrained features in dimensions: ", nan_dims) + if torch.any(torch.isnan(coords)): + raise ValueError("NaN in coords") + + A_pred = self.model(coords, feats, pretrained_feats, padding_mask=padding_mask) + + if self.model.norms: # if dict is not empty, log each entry to wandb + for key, value in self.model.norms.items(): + # check wandb runner is initialized + self.log_dict( + {f"norms/{key}": value}, on_step=True, on_epoch=False, sync_dist=True + ) - A_pred = self.model(coords, feats, padding_mask=padding_mask) # remove inf values that might happen due to float16 numerics A_pred.clamp_(torch.finfo(torch.float16).min, torch.finfo(torch.float16).max) + # above call might interfere with backward as it is an inplace operation + # A_pred = A_pred.clamp(torch.finfo(torch.float16).min, torch.finfo(torch.float16).max) mask_invalid = torch.logical_or( padding_mask.unsqueeze(1), padding_mask.unsqueeze(2) ) A_pred[mask_invalid] = 0 + # above call might interfere with backward as it is an inplace operation in "linear" causal norm + # A_pred = A_pred.masked_fill(mask_invalid, 0) loss = self.criterion(A_pred, A) - + if self.causal_norm != "none": # TODO speedup: I could softmax only the part of the matrix (upper triangular) that is not masked out A_pred_soft = torch.stack( @@ -230,10 +268,11 @@ def _common_step(self, batch, eps=torch.finfo(torch.float32).eps): for _A, _t, _m in zip(A_pred, timepoints, mask_invalid) ] ) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast(enabled=False, device_type=str(self.device)): if len(A) > 0: # debug if torch.any(torch.isnan(A_pred_soft)): + print(A_pred) print( "AAAA pred", @@ -252,9 +291,17 @@ def _common_step(self, batch, eps=torch.finfo(torch.float32).eps): timepoints=timepoints.detach().cpu().numpy(), ) + if A_pred_soft.dtype != A.dtype: + logger.warning( + "A_pred_soft has different dtype than A, casting to A.dtype" + ) + A_pred_soft = 
A_pred_soft.to(A.dtype) # Keep the non-softmaxed loss for numerical stability loss = 0.01 * loss + self.criterion_softmax(A_pred_soft, A) + if torch.any(torch.isnan(loss)): + raise ValueError("NaN after loss summing") + # Reweighting does not need gradients with torch.no_grad(): block_sum1 = torch.stack( @@ -287,6 +334,9 @@ def _common_step(self, batch, eps=torch.finfo(torch.float32).eps): mask.sum(dim=(1, 2), keepdim=True) + eps ) loss_per_sample = loss_normalized.sum(dim=(1, 2)) + + if torch.any(torch.isnan(loss_per_sample)): + raise ValueError("NaN in loss_per_sample after reduction") # Hack: weight larger samples a little more... prefactor = torch.pow(mask.sum(dim=(1, 2)), 0.2) @@ -313,7 +363,7 @@ def checkpoint_path(self, logdir): return None def configure_optimizers(self): - optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-5) + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) return dict( optimizer=optimizer, lr_scheduler=WarmupCosineLRScheduler( @@ -321,12 +371,26 @@ def configure_optimizers(self): ), ) + def on_before_optimizer_step(self, optimizer): + # self.trainer.precision_plugin.scaler.unscale_(optimizer) + # from torch.nn.utils import clip_grad_norm_ + # if self.per_param_clipping: + # for param in self.model.parameters(): + # if param.grad is not None: + # clip_grad_norm_(param, max_norm=1.0) + # Compute the 2-norm for each layer + # If using mixed precision, the gradients are already unscaled here + from lightning.pytorch.utilities import grad_norm + norms = grad_norm(self.model, norm_type=2) + self.log_dict(norms) + def training_step(self, batch, batch_idx): out = self._common_step(batch) loss = out["loss"] if torch.isnan(loss): - print("NaN loss, skipping") - return None + # print("NaN loss, skipping") + # return None + raise ValueError("NaN loss") self.log( "train_loss", @@ -335,6 +399,7 @@ def training_step(self, batch, batch_idx): on_step=False, on_epoch=True, sync_dist=True, + batch_size=batch["coords"].shape[0], ) # self.train_loss.append(loss) @@ -368,7 +433,7 @@ def validation_step(self, batch, batch_idx): loss = out["loss"] if torch.isnan(loss): print("NaN loss, skipping") - return None + raise ValueError("NaN loss") self.log( "val_loss", @@ -377,6 +442,7 @@ def validation_step(self, batch, batch_idx): on_step=False, on_epoch=True, sync_dist=True, + batch_size=batch["coords"].shape[0], ) # self.val_loss.append(loss) @@ -727,6 +793,12 @@ def _init_wandb(project, name, config): _ = wandb.init(project=project, name=name, config=config) +@rank_zero_only +def create_wandb_logger(run_name, wandb_project): + wandb_logger = WandbLogger(name=run_name, project=wandb_project) + return wandb_logger + + def train(args): args.seed = seed(args.seed) if args.model is None: @@ -766,6 +838,8 @@ def train(args): train_logger = TensorBoardLogger(logdir, name="tb") elif args.logger == "wandb": train_logger = WandbLogger(name=run_name, project=args.wandb_project) + # train_logger = create_wandb_logger(run_name, args.wandb_project) + # train_logger.log_hyperparams(training_args) # init here to get an alert on job failure even before training _init_wandb(project=args.wandb_project, name=run_name, config=vars(args)) @@ -782,6 +856,42 @@ def train(args): raise ValueError( f'Logdir {logdir} exists, set "--resume t" if you want to overwrite' ) + + pretrained_config = None + if args.features == "pretrained_feats" or args.features == "pretrained_feats_aug": + if args.pretrained_feats_model is None: + raise ValueError( + 
"Pretrained model must be defined if pretrained features are in use." + f"Available models: {AVAILABLE_PRETRAINED_BACKBONES.keys()}" + ) + if args.pretrained_feats_model not in AVAILABLE_PRETRAINED_BACKBONES: + raise ValueError( + f"Unknown pretrained model {args.pretrained_feats_model}, available: {AVAILABLE_PRETRAINED_BACKBONES.keys()}" + ) + if args.pretrained_feats_mode is None: + raise ValueError( + "Pretrained mode must be defined if pretrained features are in use." + ) + if args.features == "pretrained_feats_aug" and args.pretrained_n_augs is None: + raise ValueError( + "Number of augmentated copies must be defined if using augmented pretrained features." + ) + emb_save_path = None if args.cachedir is None else Path(args.cachedir).resolve() + if not emb_save_path.exists(): + emb_save_path.mkdir(parents=False, exist_ok=True) + # pca_save_path = ( + # Path(logdir) / "pca" if args.pretrained_feats_pca_ncomp else None + # ) + + pretrained_config = PretrainedFeatureExtractorConfig( + model_name=args.pretrained_feats_model, + mode=args.pretrained_feats_mode, + save_path=emb_save_path, + additional_features=args.pretrained_feats_additional_props, + model_path=args.pretrained_model_path, + # pca_components=args.pretrained_feats_pca_ncomp, + # pca_preprocessor_path=pca_save_path, + ) n_gpus = torch.cuda.device_count() if args.distributed else 1 if args.preallocate: @@ -801,10 +911,15 @@ def train(args): sanity_dist=args.sanity_dist, crop_size=args.crop_size, compress=args.compress, + pretrained_backbone_config=pretrained_config, + pretrained_n_augmentations=args.pretrained_n_augs, + rotate_features=args.rotate_features, ) dummy_model = TrackingTransformer( coord_dim=dummy_data.ndim, feat_dim=dummy_data.feat_dim, + pretrained_feat_dim=dummy_data.pretrained_feat_dim, + reduced_pretrained_feat_dim=args.reduced_pretrained_feat_dim, d_model=args.d_model, pos_embed_per_dim=args.pos_embed_per_dim, feat_embed_per_dim=args.feat_embed_per_dim, @@ -817,6 +932,8 @@ def train(args): attn_positional_bias_n_spatial=args.attn_positional_bias_n_spatial, attn_dist_mode=args.attn_dist_mode, causal_norm=args.causal_norm, + disable_xy_coords=args.disable_xy_coords, + disable_all_coords=args.disable_all_coords, ) dummy_model_lightning = WrappedLightningModule( @@ -829,6 +946,8 @@ def train(args): tracking_frequency=args.tracking_frequency, batch_val_tb_idx=0, div_upweight=args.div_upweight, + # per_param_clipping=args.clip_grad_per_param, + weight_decay=args.weight_decay, ) dummy_model_lightning.to(device) preallocate_memory( @@ -852,7 +971,7 @@ def train(args): if args.only_prechecks: return locals() - + dataset_kwargs = dict( ndim=args.ndim, detection_folders=args.detection_folders, @@ -864,6 +983,9 @@ def train(args): sanity_dist=args.sanity_dist, crop_size=args.crop_size, compress=args.compress, + pretrained_backbone_config=pretrained_config, + pretrained_n_augmentations=args.pretrained_n_augs, + rotate_features=args.rotate_features, ) sampler_kwargs = dict( batch_size=args.batch_size, @@ -879,12 +1001,12 @@ def train(args): pin_memory=True, collate_fn=collate_sequence_padding, ) - + # Sampler gets wrapped with distributed sampler, which cannot sample with replacement datamodule = BalancedDataModule( input_train=args.input_train, input_val=args.input_val, - cachedir=args.cachedir, + cachedir=args.cachedir if args.cache else None, augment=args.augment, distributed=args.distributed, dataset_kwargs=dataset_kwargs, @@ -905,15 +1027,15 @@ def train(args): ) 
callbacks.append(pl.pytorch.callbacks.Timer(interval="epoch")) - # Mostly for stopping broken runs - callbacks.append( - pl.pytorch.callbacks.EarlyStopping( - monitor="val_loss", - patience=args.epochs // 6, - mode="min", - verbose=True, - ) - ) + # # Mostly for stopping broken runs + # callbacks.append( + # pl.pytorch.callbacks.EarlyStopping( + # monitor="val_loss", + # patience=args.epochs // 6, + # mode="min", + # verbose=True, + # ) + # ) if args.example_images: callbacks.append(ExampleImages()) @@ -932,13 +1054,23 @@ def train(args): else: model = TrackingTransformer.from_folder(fpath, args=args) else: - feat_dim = 0 if args.features == "none" else 7 if args.ndim == 2 else 12 + # feat_dim = 0 if args.features == "none" else 7 if args.ndim == 2 else 12 + if args.features == "pretrained_feats" or args.features == "pretrained_feats_aug": # TODO find a way to truly automate this + feat_dim = pretrained_config.additional_feat_dim + elif args.features == "wrfeat": + feat_dim = WRFeatures.PROPERTIES_DIMS[DEFAULT_PROPERTIES][args.ndim] + else: + feat_dim = CTCData.get_feat_dim(args.features, args.ndim) + + pretrained_feat_dim = 0 if pretrained_config is None else pretrained_config.feat_dim + model = TrackingTransformer( # coord_dim=datasets["train"].datasets[0].ndim, coord_dim=args.ndim, # feat_dim=datasets["train"].datasets[0].feat_dim, - # FIXME hardcoded feat_dim feat_dim=feat_dim, + pretrained_feat_dim=pretrained_feat_dim, + reduced_pretrained_feat_dim=args.reduced_pretrained_feat_dim, d_model=args.d_model, pos_embed_per_dim=args.pos_embed_per_dim, feat_embed_per_dim=args.feat_embed_per_dim, @@ -951,6 +1083,8 @@ def train(args): attn_positional_bias_n_spatial=args.attn_positional_bias_n_spatial, attn_dist_mode=args.attn_dist_mode, causal_norm=args.causal_norm, + disable_xy_coords=args.disable_xy_coords, + disable_all_coords=args.disable_all_coords, ) model_lightning = WrappedLightningModule( @@ -963,6 +1097,8 @@ def train(args): tracking_frequency=args.tracking_frequency, batch_val_tb_idx=batch_val_tb_idx, div_upweight=args.div_upweight, + # per_param_clipping=args.clip_grad_per_param, + weight_decay=args.weight_decay, ) # Compiling does not work! 
     # model_lightning = torch.compile(model_lightning)
@@ -983,13 +1119,14 @@
                 tracking_frequency=args.tracking_frequency,
                 batch_val_tb_idx=batch_val_tb_idx,
                 div_upweight=args.div_upweight,
+                weight_decay=args.weight_decay,
             )
         else:
             logging.warning(f"No checkpoint found in {logdir}")

     model_lightning.to(device)
     num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    logging.info(f"Model has {numerize.numerize(num_params)} parameters")
+    logging.info(f"Model has {humanize.intword(num_params)} parameters")

     if args.distributed:
         strategy = "ddp"
@@ -1003,16 +1140,24 @@
     else:
         profiler = None

+    import platform
+
+    from lightning.pytorch.strategies import DDPStrategy
+
+    # NCCL is unavailable on Windows, so fall back to the gloo backend there
+    if args.distributed and platform.system() == "Windows":
+        strategy = DDPStrategy(process_group_backend="gloo")
+
     trainer = pl.Trainer(
-        accelerator="cuda",
+        accelerator="gpu" if torch.cuda.is_available() else "cpu",
         strategy=strategy,
-        devices=n_gpus,
+        devices=n_gpus if torch.cuda.is_available() else 1,
         precision="16-mixed" if args.mixedp else 32,
         logger=train_logger,
         num_nodes=1,
         max_epochs=args.epochs,
         callbacks=callbacks,
         profiler=profiler,
+        gradient_clip_val=1.0,
     )

     t = default_timer()
@@ -1030,7 +1175,7 @@
     print(f"Time elapsed: {(default_timer() - t) / 60:.02f} min")
     print(f"CPU Memory used: {(_process_memory() - memory) / 1e9:.2f} GB")
     print(f"GPU Memory used : {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
-    
+
     return locals()

@@ -1045,7 +1190,8 @@
         "--config",
         is_config_file=True,
         help="config file path",
-        default="configs/vanvliet.yaml",
+        # default="configs/vanvliet.yaml",
+        default="scripts/example_config.yaml",
     )
     parser.add_argument("-o", "--outdir", type=str, default="runs")
     parser.add_argument("--name", type=str, help="Name to append to timestamp")
@@ -1105,14 +1251,7 @@
     parser.add_argument(
         "--features",
         type=str,
-        choices=[
-            "none",
-            "regionprops",
-            "regionprops2",
-            "patch",
-            "patch_regionprops",
-            "wrfeat",
-        ],
+        choices=sorted(CTCData.VALID_FEATURES),
         default="wrfeat",
     )
     parser.add_argument(
@@ -1196,6 +1335,76 @@
             " imbalance)"
         ),
     )
+    # Pretrained feats + extra arguments
+    parser.add_argument(
+        "--pretrained_feats_model",
+        type=str,
+        choices=list(AVAILABLE_PRETRAINED_BACKBONES.keys()),
+        default=None,
+        help="If --features is pretrained_feats or pretrained_feats_aug, the model to use for feature extraction",
+    )
+    parser.add_argument(
+        "--pretrained_model_path",
+        type=str,
+        default=None,
+        help="Path to a pretrained model to use for feature extraction. Only valid if --features is pretrained_feats.",
+    )
+    parser.add_argument(
+        "--pretrained_feats_mode",
+        type=str,
+        # choices=["nearest_patch", "mean_patches_bbox", "max_patches_bbox", "mean_patches_exact", "max_patches_exact"],
+        choices=list(PretrainedFeatsExtractionMode.__args__),
+        default=None,
+        help="If --features is pretrained_feats or pretrained_feats_aug, the mode to use for feature extraction",
+    )
+    parser.add_argument(
+        "--pretrained_feats_additional_props",
+        type=str,
+        choices=list(_PROPERTIES.keys()),
+        default=None,
+        help="Additional regionprops features to use in addition to pretrained model embeddings",
+    )
+    # parser.add_argument(
+    #     "--pretrained_feats_pca_ncomp",
+    #     type=int,
+    #     default=None,
+    #     help="Number of components to use for PCA dimensionality reduction. 
If None, no PCA is applied.", + # ) + parser.add_argument( + "--weight_decay", + type=float, + default=0.01, + help="Weight decay for the AdamW optimizer", + ) + parser.add_argument( + "--pretrained_n_augs", + type=int, + default=None, + help="Number of augmentations to use for pretrained features. Only valid if features is pretrained_feats_aug", + ) + parser.add_argument( + "--disable_xy_coords", + type=str2bool, + default=False, + help="Disable x and y coordinates as input features. --features cannot be none if True.", + ) + parser.add_argument( + "--disable_all_coords", + type=str2bool, + default=False, + help="Disable all coordinates T(Z)XY as input features. --features cannot be none if True.", + ) + parser.add_argument( + "--rotate_features", + type=str2bool, + default=False, + help="Rotate features using augmented coordinates. features must be 'pretrained_feats' or 'pretrained_feats_aug' if True.", + ) + parser.add_argument( + "--reduced_pretrained_feat_dim", + type=int, + default=128, + ) args, unknown_args = parser.parse_known_args() diff --git a/scripts/wrfeat_aug_bacteria.yaml b/scripts/wrfeat_aug_bacteria.yaml new file mode 100644 index 0000000..5a01feb --- /dev/null +++ b/scripts/wrfeat_aug_bacteria.yaml @@ -0,0 +1,96 @@ +# Forked from zih_bacteria.yaml +name: vanvliet_wrfeat +# model: /home/achard/trastra_v2/BEST_2025-06-04_13-23-38_vanvliet_sam21_aug +epochs: 800 +warmup_epochs: 5 +window: 4 +attn_dist_mode: v1 +delta_cutoff: 1 +num_encoder_layers: 4 +num_decoder_layers: 4 +causal_norm: none +d_model: 256 +# dropout: 0.05 +pos_embed_per_dim: 32 +lr: 0.0001 +train_samples: 32000 +max_tokens: 2048 +batch_size: 64 +detection_folders: +- TRA +crop_size: +- 320 +- 320 +attn_positional_bias: rope +ndim: 2 +# Features config +# features: pretrained_feats_aug +# features: pretrained_feats +features: wrfeat +#### Additional parameters for pretrained_feats +augment: 3 +# pretrained_n_augs: 25 +# pretrained_feats_model: facebook/sam2.1-hiera-base-plus +# pretrained_feats_model: facebookresearch/co-tracker +# pretrained_feats_model: weigertlab/tarrow +# pretrained_model_path: /home/achard/tarrow_runs/vanvliet_backbone_unet_delta1-2 +# pretrained_feats_model: debug/encoded_labels +# reduced_pretrained_feat_dim: 128 +# pretrained_feats_mode: mean_patches_exact +# pretrained_feats_mode: nearest_patch +# pretrained_feats_additional_props: regionprops_small +# feat_embed_per_dim: 8 # Reduce additional dimensions added to the features via positional embedding +# rotate_features: true +# disable_xy_coords: true +# disable_all_coords: True +cachedir: /backup/achard/cache +outdir: /home/achard/trastra_v2 +# weight_decay: 0.01 +#### Logger +logger: wandb +wandb_project: "trackastra_v2" +# Paths config +cache: true +compress: true +distributed: false +### + +input_train: +- /backup/achard/CTC_DATA/vanvliet/cib/140409-03 # +- /backup/achard/CTC_DATA/vanvliet/cib/140415-08 +- /backup/achard/CTC_DATA/vanvliet/recA/151027-05 +- /backup/achard/CTC_DATA/vanvliet/recA/151027-06 +# - /backup/achard/CTC_DATA/vanvliet/recA/151027-10 # Has an empty frame, annoying +- /backup/achard/CTC_DATA/vanvliet/recA/151028-01 +- /backup/achard/CTC_DATA/vanvliet/recA/151029-05 +- /backup/achard/CTC_DATA/vanvliet/recA/151029-11 +- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-6 +- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E2-1 +- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E2-2 +- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E3-11 +- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E4-12 +- 
/backup/achard/CTC_DATA/vanvliet/metA/150317-07
+- /backup/achard/CTC_DATA/vanvliet/metA/150318-06
+- /backup/achard/CTC_DATA/vanvliet/metA/150331-12
+- /backup/achard/CTC_DATA/vanvliet/metA/151222-10
+- /backup/achard/CTC_DATA/vanvliet/metA/151222-11
+- /backup/achard/CTC_DATA/vanvliet/pheA/150324-03
+- /backup/achard/CTC_DATA/vanvliet/pheA/150324-05
+- /backup/achard/CTC_DATA/vanvliet/pheA/150325-04
+- /backup/achard/CTC_DATA/vanvliet/pheA/160112-04
+- /backup/achard/CTC_DATA/vanvliet/trpL/150303-01
+- /backup/achard/CTC_DATA/vanvliet/trpL/150303-08
+
+input_val:
+- /backup/achard/CTC_DATA/vanvliet/trpL/150428-08
+- /backup/achard/CTC_DATA/vanvliet/trpL/151021-11
+
+input_test:
+# - data/CTC_DATA/vanvliet/cib/140408-01 # dense labels not precise
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-1 # ok
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151029_E1-5 # ok
+- /backup/achard/CTC_DATA/vanvliet/rpsM/151101_E3-12 # ok
+- /backup/achard/CTC_DATA/vanvliet/trpL/150309-04 # ok
+# - data/CTC_DATA/vanvliet/trpL/150310-05 # was empty
+- /backup/achard/CTC_DATA/vanvliet/trpL/150310-11 # not in delta test set. quite some global moving!
+- /backup/achard/CTC_DATA/vanvliet/pheA/160112-06 # ok. tricky. not in delta test set.
diff --git a/setup.cfg b/setup.cfg
index 8afd34c..e4b3d2b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -19,11 +19,13 @@ classifiers =
 [options]
 packages = find:
 install_requires =
-    numpy<2
+    numpy
     matplotlib
     scipy
+    h5py
     pandas
-    numerize
+    dask
+    humanize
     configargparse
     tensorboard
     pyyaml
@@ -34,6 +36,7 @@ install_requires =
     kornia>=0.7.0 # TODO remove old augs
     torch
     torchvision
+    # transformers
     lz4
     imagecodecs>=2023.7.10
     wandb
@@ -55,6 +58,12 @@ dev =
     build
 test =
     pytest
+pretrained_feats =
+    transformers
+    sam-2 @ git+https://github.com/facebookresearch/sam2.git
+    segment-anything @ git+https://github.com/facebookresearch/segment-anything.git
+    scikit-learn
+    zarr

 [options.entry_points]
 console_scripts =
diff --git a/tests/test_attn.py b/tests/test_attn.py
new file mode 100644
index 0000000..ef09f3e
--- /dev/null
+++ b/tests/test_attn.py
@@ -0,0 +1,37 @@
+import torch
+from trackastra.model.model_parts import RelativePositionalAttention
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def test_padding_mask_invariance():
+    """Output for the first sample must not depend on the batch padding."""
+    model = RelativePositionalAttention(
+        coord_dim=2,
+        embed_dim=64,
+        n_head=2,
+        mode="rope",
+        attn_dist_mode="v2",
+    )
+    model.eval()
+    model.to(device)
+
+    B, N = 3, 11
+    q = torch.rand(B, N, 64).to(device)
+    x = torch.rand(B, N, 3).to(device)
+
+    pad_mask = torch.zeros((B, N), dtype=torch.bool).to(device)
+    pad_mask[0, -2:] = True
+    pad_mask[1, -3:] = True
+    pad_mask[2, -4:] = True
+
+    u = model(q, q, q, coords=x, padding_mask=pad_mask)
+    u1 = model(q[:1], q[:1], q[:1], coords=x[:1], padding_mask=pad_mask[:1])
+
+    err = torch.abs(u[:1] - u1).mean()
+    print(f"Error: {err:.4f}")
+    assert torch.allclose(u[:1], u1, rtol=1e-3, atol=1e-6)
+
+
+if __name__ == "__main__":
+    test_padding_mask_invariance()
diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py
index 153c485..89078fb 100644
--- a/tests/test_augmentations.py
+++ b/tests/test_augmentations.py
@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 from scipy.ndimage import maximum_filter
-from trackastra.data import AugmentationPipeline
+from trackastra.data import AugmentationPipeline, CTCData

 def plot_augs(b1, b2):
@@ -54,6 +54,8 @@ def test_augpipeline(plot=False):

 if __name__ == "__main__":
+
+    test_augpipeline(plot=True)
     # pipe = RandomCrop((30, 40, 10, 
20), ensure_inside_points=True) diff --git a/trackastra/data/__init__.py b/trackastra/data/__init__.py index fbddf06..be4c89c 100644 --- a/trackastra/data/__init__.py +++ b/trackastra/data/__init__.py @@ -14,5 +14,31 @@ BalancedDistributedSampler, ) from .example_data import example_data_bacteria, example_data_fluo_3d, example_data_hela -from .utils import filter_track_df, load_tiff_timeseries, load_tracklet_links +from .pretrained_augmentations import ( + PretrainedAugmentations, + PretrainedIntensityAugmentations, + PretrainedMovementAugmentations, +) +from .pretrained_features import ( + CellposeSAMFeatures, + CoTrackerFeatures, + DinoV2Features, + FeatureExtractor, + FeatureExtractorAugWrapper, + HieraFeatures, + MicroSAMFeatures, + PretrainedBackboneType, + PretrainedFeatsExtractionMode, + PretrainedFeatureExtractorConfig, + SAM2Features, + SAM2HighresFeatures, + SAMFeatures, + TAPFeatures, +) +from .utils import ( + filter_track_df, + load_tiff_timeseries, + load_tracklet_links, + make_hashable, +) from .wrfeat import WRFeatures, build_windows, get_features diff --git a/trackastra/data/data.py b/trackastra/data/data.py index 9df93fa..b1fde4c 100644 --- a/trackastra/data/data.py +++ b/trackastra/data/data.py @@ -1,8 +1,13 @@ +from __future__ import annotations + +import hashlib +import json import logging from collections.abc import Sequence +from functools import lru_cache from pathlib import Path from timeit import default_timer -from typing import Literal +from typing import TYPE_CHECKING, ClassVar, Literal import joblib import lz4.frame @@ -11,6 +16,7 @@ import pandas as pd import tifffile import torch +import zarr from numba import njit from scipy import ndimage as ndi from scipy.spatial.distance import cdist @@ -32,13 +38,22 @@ extract_features_regionprops, ) from trackastra.data.matching import matching +from trackastra.data.pretrained_augmentations import PretrainedAugmentations +from trackastra.utils import blockwise_sum, masks2properties, normalize -# from ..utils import blockwise_sum, normalize -from trackastra.utils import blockwise_sum, normalize +from .utils import make_hashable logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - +# logger.setLevel(logging.INFO) +logger.setLevel(logging.DEBUG) # FIXME go back to INFO for release + +if TYPE_CHECKING: + from trackastra.data.pretrained_features import ( + PretrainedBackboneType, + PretrainedFeatsExtractionMode, + PretrainedFeatureExtractorConfig, + ) + def _filter_track_df(df, start_frame, end_frame, downscale): """Only keep tracklets that are present in the given time interval.""" @@ -115,6 +130,53 @@ def wrapper(*args, **kwargs): class CTCData(Dataset): + """Cell Tracking Challenge data loader.""" + # Amount of feature per mode per dimension + FEATURES_DIMENSIONS: ClassVar = { + "regionprops": { + 2: 7, + 3: 11, + }, + "regionprops2": { + 2: 6, + 3: 11, + }, + "regionprops_full": { + 2: 9, + 3: 13, + }, + "patch": { + 2: 256, + 3: 256, + }, + "patch_regionprops": { + 2: 256 + 5, + 3: 256 + 8, + }, + "none": { + 2: 0, + 3: 0, + } + # "wrfeat" -> defined by wrfeat + # "pretrained_feats":{ # -> defined by PretrainedFeatureExtractorConfig.feat_dim + } + VALID_FEATURES: ClassVar = { + "none", + "regionprops", + "regionprops2", + "patch", + "patch_regionprops", + "wrfeat", + "pretrained_feats", + "pretrained_feats_aug", + } + + def __new__(cls, *args, **kwargs): + # Check if features is "pretrained_feats_aug"; if it is, use CTCDataAugPretrainedFeats class + if kwargs.get("features") == "pretrained_feats_aug": 
+ return super().__new__(globals()["CTCDataAugPretrainedFeats"]) + return super().__new__(cls) + def __init__( self, root: str = "", @@ -134,40 +196,59 @@ def __init__( "patch", "patch_regionprops", "wrfeat", + "pretrained_feats", + "pretrained_feats_aug", ] = "wrfeat", sanity_dist: bool = False, crop_size: tuple | None = None, return_dense: bool = False, compress: bool = False, + pretrained_backbone_config: PretrainedFeatureExtractorConfig | None = None, + # pca_preprocessor: EmbeddingsPCACompression | None = None, + rotate_features: bool = False, + load_immediately: bool = True, **kwargs, ) -> None: - """_summary_. - - Args: - root (str): - Folder containing the CTC TRA folder. - ndim (int): - Number of dimensions of the data. Defaults to 2d - (if ndim=3 and data is two dimensional, it will be cast to 3D) - detection_folders: - List of relative paths to folder with detections. - Defaults to ["TRA"], which uses the ground truth detections. - window_size (int): - Window size for transformer. - slice_pct (tuple): - Slice the dataset by percentages (from, to). - augment (int): - if 0, no data augmentation. if > 0, defines level of data augmentation. - features (str): - Types of features to use. - sanity_dist (bool): - Use euclidian distance instead of the association matrix as a target. - crop_size (tuple): - Size of the crops to use for augmentation. If None, no cropping is used. - return_dense (bool): - Return dense masks and images in the data samples. - compress (bool): - Compress elements/remove img if not needed to save memory for large datasets + """Args: + root (str): + Folder containing the CTC TRA folder. + ndim (int): + Number of dimensions of the data. Defaults to 2d + (if ndim=3 and data is two dimensional, it will be cast to 3D) + detection_folders: + List of relative paths to folder with detections. + Defaults to ["TRA"], which uses the ground truth detections. + window_size (int): + Window size for transformer. + slice_pct (tuple): + Slice the dataset by percentages (from, to). + augment (int): + if 0, no data augmentation. if > 0, defines level of data augmentation. + features (str): + Types of features to use. + sanity_dist (bool): + Use euclidian distance instead of the association matrix as a target. + crop_size (tuple): + Size of the crops to use for augmentation. If None, no cropping is used. + return_dense (bool): + Return dense masks and images in the data samples. + compress (bool): + Compress elements/remove img if not needed to save memory for large datasets + pretrained_backbone_config (PretrainedFeatureExtractorConfig): + Configuration for the pretrained backbone. + If mode is set to "pretrained_feats", this configuration is used to extract features. + Ignored otherwise. + rotate_features (bool): + Apply rotation to features based on (augmented) coordinates. + Only valid if used with "pretrained_feats" or "pretrained_feats_aug" mode. + load_immediately (bool): + If True, load the data immediately. If False, load the data lazily. + If False, you need to call `start_loading()` to load the data. + Defaults to True. + # pca_preprocessor (EmbeddingsPCACompression): + # PCA preprocessor for the pretrained features. + # If mode is set to "pretrained_feats", this is used to reduce the dimensionality of the features. + # Ignored otherwise. 
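+            Example:
+                Minimal usage sketch (the path below is illustrative; any folder in
+                CTC layout works)::
+
+                    ds = CTCData(
+                        root="data/CTC_DATA/some_dataset/01",
+                        ndim=2,
+                        window_size=4,
+                        features="wrfeat",
+                        detection_folders=["TRA"],
+                    )
+                    sample = ds[0]  # dict with "features", "coords", "assoc_matrix", ...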
""" super().__init__() @@ -180,14 +261,29 @@ def __init__( self.downscale_spatial = downscale_spatial self.downscale_temporal = downscale_temporal self.detection_folders = detection_folders - self.ndim = ndim + self._ndim = ndim self.features = features + self.rotate_feats = rotate_features - if features not in ("none", "wrfeat") and features not in _PROPERTIES[ndim]: - raise ValueError( - f"'{features}' not one of the supported {ndim}D features" - f" {tuple(_PROPERTIES[ndim].keys())}" - ) + if features not in self.VALID_FEATURES: + if features not in _PROPERTIES[self._ndim] and features != "wrfeat": + raise ValueError( + f"'{features}' not one of the supported {self._ndim}D features" + f" {tuple(_PROPERTIES[self._ndim].keys())}" + ) + + if features == "pretrained_feats" or features == "pretrained_feats_aug": + try: + if TYPE_CHECKING: + import transformers + transformers.__version__ + except ImportError as e: + msg = """Please install pretrained_feats extra requirements to use pretrained features mode.\n + Run : + pip install trackastra[pretrained_feats] + to install the required dependencies. + """ + raise ImportError(msg) from e logger.info(f"ROOT (config): \t{self.root}") self.root, self.gt_tra_folder = self._guess_root_and_gt_tra_folder(self.root) @@ -209,9 +305,18 @@ def __init__( self.img_folder = self._guess_img_folder(self.root) logger.info(f"IMG (guessed):\t{self.img_folder}") - self.feat_dim, self.augmenter, self.cropper = self._setup_features_augs( - ndim, features, augment, crop_size - ) + self._pretrained_config = None + if features == "pretrained_feats" or features == "pretrained_feats_aug": + if pretrained_backbone_config is None: + raise ValueError("Pretrained backbone config must be provided for pretrained features mode.") + self.pretrained_config = pretrained_backbone_config + if self.pretrained_config.save_path is None: + self.pretrained_config.save_path = self.img_folder + self.FEATURES_DIMENSIONS["pretrained_feats"] = self.pretrained_config.feat_dim + + self.augment_level = augment + self.crop_size = crop_size + self.augmenter, self.cropper = self._setup_features_augs() if window_size <= 1: raise ValueError("window must be >1") @@ -224,35 +329,139 @@ def __init__( self.compress = compress self.start_frame = 0 self.end_frame = None - + + # Pretrained model attributes for feature extraction if specified + self._pretrained_model_input_size_factor = 1 + self.feature_extractor_input_size = None + self.feature_extractor_save_path = None + self.pretrained_features = None + self.feature_extractor = None + self.pretrained_feature_augmenter = None + # self.pca_preprocessor = pca_preprocessor + + if load_immediately: + self.start_loading() + + if kwargs: + logger.warning(f"Unused kwargs: {kwargs}") + + def start_loading(self): start = default_timer() - if self.features == "wrfeat": - self.windows = self._load_wrfeat() - else: - self.windows = self._load() + self._init_features() # loads and creates windows self.n_divs = self._get_ndivs(self.windows) - if len(self.windows) > 0: - self.ndim = self.windows[0]["coords"].shape[1] - self.n_objects = tuple(len(t["coords"]) for t in self.windows) - logger.info( - f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track" - f" windows from {self.root} ({default_timer() - start:.1f}s)\n" - ) - else: - self.n_objects = 0 - logger.warning(f"Could not load any tracks from {self.root}") + # if len(self.windows) > 0: + # self.ndim = self.windows[0]["coords"].shape[1] + # self.n_objects = tuple(len(t["coords"]) for t in 
self.windows) + # logger.info( + # f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track" + # f" windows from {self.root} ({default_timer() - start:.1f}s)\n" + # ) + # else: + # self.n_objects = 0 + # logger.warning(f"Could not load any tracks from {self.root}") + self._get_ndim_and_nobj(start) if self.compress: self._compress_data() - # def from_ctc - + def _init_features(self): + if self.features == "wrfeat" or self.features == "pretrained_feats": + self.windows = self._load_wrfeat() + else: + self.windows = self._load() + + @property + def config(self): + return { + "root": str(self.root), + "ndim": self.ndim, + "use_gt": self.use_gt, + "detection_folders": self.detection_folders, + "window_size": self.window_size, + "max_tokens": self.max_tokens, + "slice_pct": self.slice_pct, + "downscale_spatial": self.downscale_spatial, + "downscale_temporal": self.downscale_temporal, + "augment": self.augment_level, + "features": self.features, + "sanity_dist": self.sanity_dist, + "crop_size": self.crop_size, + "return_dense": self.return_dense, + "compress": self.compress, + "pretrained_config": ( + self.pretrained_config.to_dict() if self.pretrained_config else None + ), + "rotate_features": self.rotate_feats, + } + + @property + def config_hash(self): + """Returns a hash of the configuration.""" + cfg = make_hashable(self.config) + config_str = json.dumps(cfg, sort_keys=True) + return hashlib.sha256(config_str.encode()).hexdigest() + + @property + def ndim(self): + return self._ndim + + @ndim.setter + def ndim(self, value: int): + if value not in (2, 3): + raise ValueError(f"ndim must be 2 or 3, got {value}") + self._ndim = value + + @property + def feat_dim(self): + if self.pretrained_config is None: + return self.FEATURES_DIMENSIONS[self.features][self.ndim] + elif self.features == "wrfeat": + return wrfeat.WRFeatures.PROPERTIES_DIMS[ + wrfeat.DEFAULT_PROPERTIES + ][self.ndim] + else: + return self.pretrained_config.additional_feat_dim + + @property + def pretrained_config(self): + return self._pretrained_config + + @property + def pretrained_feat_dim(self): + if self._pretrained_config is None: + return 0 + return self._pretrained_config.feat_dim + + @pretrained_config.setter + def pretrained_config(self, config: PretrainedFeatureExtractorConfig): + if isinstance(config, dict): + from trackastra.data.pretrained_features import ( + PretrainedFeatureExtractorConfig, + ) + config = PretrainedFeatureExtractorConfig.from_dict(config) + self.update_pretrained_feat_dim(config) + self._pretrained_config = config + + def update_pretrained_feat_dim(self, config): + try: + self.FEATURES_DIMENSIONS["pretrained_feats"] = config.feat_dim + except AttributeError as e: + if isinstance(config, dict): + self.FEATURES_DIMENSIONS["pretrained_feats"] = config["feat_dim"] + else: + raise e + + @staticmethod + def get_feat_dim(features, ndim, ): + return CTCData.FEATURES_DIMENSIONS[features][ndim] + @classmethod def from_arrays(cls, imgs: np.ndarray, masks: np.ndarray, train_args: dict): self = cls(**train_args) + # def from_ctc # for key, value in train_args.items(): # setattr(self, key, value) @@ -291,13 +500,11 @@ def from_arrays(cls, imgs: np.ndarray, masks: np.ndarray, train_args: dict): # self.img_folder = self._guess_img_folder(self.root) # logger.info(f"IMG:\t\t{self.img_folder}") - self.feat_dim, self.augmenter, self.cropper = self._setup_features_augs( - self.ndim, self.features, self.augment, self.crop_size - ) + self.augmenter, self.cropper = self._setup_features_augs() start = 
default_timer() - if self.features == "wrfeat": + if self.features == "wrfeat" or self.features == "pretrained_feats": self.windows = self._load_wrfeat() else: self.windows = self._load() @@ -317,7 +524,19 @@ def from_arrays(cls, imgs: np.ndarray, masks: np.ndarray, train_args: dict): if self.compress: self._compress_data() - + + def _get_ndim_and_nobj(self, start): + if len(self.windows) > 0: + self.ndim = self.windows[0]["coords"].shape[1] + self.n_objects = tuple(len(t["coords"]) for t in self.windows) + logger.info( + f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track" + f" windows from {self.root} ({default_timer() - start:.1f}s)\n" + ) + else: + self.n_objects = 0 + logger.warning(f"Could not load any tracks from {self.root}") + def _get_ndivs(self, windows): n_divs = [] for w in tqdm(windows, desc="Counting divisions", leave=False): @@ -336,59 +555,47 @@ def _get_ndivs(self, windows): return n_divs def _setup_features_augs( - self, ndim: int, features: str, augment: int, crop_size: tuple[int] + self ): - if self.features == "wrfeat": - return self._setup_features_augs_wrfeat(ndim, features, augment, crop_size) + if self.features in ["wrfeat", "pretrained_feats"]: + return self._setup_features_augs_wrfeat() cropper = ( RandomCrop( - crop_size=crop_size, - ndim=ndim, + crop_size=self.crop_size, + ndim=self.ndim, use_padding=False, ensure_inside_points=True, ) - if crop_size is not None + if self.crop_size is not None else None ) # Hack if self.features == "none": - return 0, default_augmenter, cropper - - if ndim == 2: - augmenter = AugmentationPipeline(p=0.8, level=augment) if augment else None - feat_dim = { - "none": 0, - "regionprops": 7, - "regionprops2": 6, - "patch": 256, - "patch_regionprops": 256 + 5, - }[features] - elif ndim == 3: - augmenter = AugmentationPipeline(p=0.8, level=augment) if augment else None - feat_dim = { - "none": 0, - "regionprops2": 11, - "patch_regionprops": 256 + 8, - }[features] - - return feat_dim, augmenter, cropper + return default_augmenter, cropper + + augmenter = AugmentationPipeline(p=0.8, level=self.augment_level) if self.augment_level else None + + return augmenter, cropper def _compress_data(self): # compress masks and assoc_matrix logger.info("Compressing masks and assoc_matrix to save memory") for w in self.windows: - w["mask"] = _CompressedArray(w["mask"]) + if "mask" in w: + w["mask"] = _CompressedArray(w["mask"]) # dont compress full imgs (as needed for patch features) - w["img"] = _CompressedArray(w["img"]) + if "img" in w: + w["img"] = _CompressedArray(w["img"]) w["assoc_matrix"] = _CompressedArray(w["assoc_matrix"]) self.gt_masks = _CompressedArray(self.gt_masks) self.det_masks = {k: _CompressedArray(v) for k, v in self.det_masks.items()} # dont compress full imgs (as needed for patch features) self.imgs = _CompressedArray(self.imgs) - def _guess_root_and_gt_tra_folder(self, inp: Path): + @staticmethod + def _guess_root_and_gt_tra_folder(inp: Path): """Guesses the root and the ground truth folder from a given input path. 
Args: @@ -410,15 +617,17 @@ def _guess_root_and_gt_tra_folder(self, inp: Path): tra = ctc_tra if ctc_tra.exists() else inp / "TRA" # 01 --> 01, 01_GT/TRA or 01/TRA return inp, tra - - def _guess_img_folder(self, root: Path): + + @staticmethod + def _guess_img_folder(root: Path): """Guesses the image folder corresponding to a root.""" if (root / "img").exists(): return root / "img" else: return root - def _guess_mask_folder(self, root: Path, gt_tra: Path): + @staticmethod + def _guess_mask_folder(root: Path, gt_tra: Path): """Guesses the mask folder corresponding to a root. In CTC format, we use silver truth segmentation masks. @@ -452,7 +661,7 @@ def _guess_det_folder(cls, root: Path, suffix: str): def __len__(self): return len(self.windows) - + def _load_gt(self): logger.info("Loading ground truth") self.start_frame = int( @@ -539,7 +748,7 @@ def _load_tiffs(self, folder: Path, dtype=None): logger.debug(f"Loaded array of shape {x.shape} from {folder}") return x - def _masks2properties(self, masks): + def _masks2properties(self, masks, return_props_by_time=False): """Turn label masks into lists of properties, sorted (ascending) by time and label id. Args: @@ -550,38 +759,7 @@ def _masks2properties(self, masks): ts: List of timepoints coords: List of coordinates """ - # Get coordinates, timepoints, and labels of detections - labels = [] - ts = [] - coords = [] - properties_by_time = dict() - assert len(self.imgs) == len(masks) - for _t, frame in tqdm( - enumerate(masks), - # total=len(detections), - leave=False, - desc="Loading masks and properties", - ): - regions = regionprops(frame) - t_labels = [] - t_ts = [] - t_coords = [] - for _r in regions: - t_labels.append(_r.label) - t_ts.append(_t) - centroid = np.array(_r.centroid).astype(int) - t_coords.append(centroid) - - properties_by_time[_t] = dict(coords=t_coords, labels=t_labels) - labels.extend(t_labels) - ts.extend(t_ts) - coords.extend(t_coords) - - labels = np.array(labels, dtype=int) - ts = np.array(ts, dtype=int) - coords = np.array(coords, dtype=int) - - return labels, ts, coords, properties_by_time + return masks2properties(self.imgs, masks, return_props_by_time=return_props_by_time) def _load_tracklet_links(self, folder: Path) -> pd.DataFrame: df = pd.read_csv( @@ -619,16 +797,17 @@ def _check_dimensions(self, x: np.ndarray): raise ValueError(f"Expected 3D data, got {x.ndim - 1}D data") return x - def _load(self): + def _prepare_masks_and_imgs(self, return_orig_imgs=False): # Load ground truth - logger.info("Loading ground truth") self.gt_masks, self.gt_track_df = self._load_gt() - self.gt_masks = self._check_dimensions(self.gt_masks) # Load images if self.img_folder is None: - self.imgs = np.zeros_like(self.gt_masks) + if self.gt_masks is not None: + self.imgs = np.zeros_like(self.gt_masks) + else: + raise NotImplementedError("No images and no GT masks") else: logger.info("Loading images") imgs = self._load_tiffs(self.img_folder, dtype=np.float32) @@ -644,8 +823,38 @@ def _load(self): for im, mask in zip(self.imgs, self.gt_masks) ] ) + if np.any(np.isnan(self.imgs)): + raise ValueError("Compressed images contain NaN values") assert len(self.gt_masks) == len(self.imgs) + if return_orig_imgs: + return imgs + + def _load(self): + # # Load ground truth + # logger.info("Loading ground truth") + # self.gt_masks, self.gt_track_df = self._load_gt() + # self.gt_masks = self._check_dimensions(self.gt_masks) + + # # Load images + # if self.img_folder is None: + # self.imgs = np.zeros_like(self.gt_masks) + # else: + # 
logger.info("Loading images") + # imgs = self._load_tiffs(self.img_folder, dtype=np.float32) + # self.imgs = np.stack( + # [normalize(_x) for _x in tqdm(imgs, desc="Normalizing", leave=False)] + # ) + # self.imgs = self._check_dimensions(self.imgs) + # if self.compress: + # # prepare images to be compressed later (e.g. removing non masked parts for regionprops features) + # self.imgs = np.stack( + # [ + # _compress_img_mask_preproc(im, mask, self.features) + # for im, mask in zip(self.imgs, self.gt_masks) + # ] + # ) + self._prepare_masks_and_imgs() # Load each of the detection folders and create data samples with a sliding window windows = [] @@ -662,7 +871,7 @@ def _load(self): det_ts, det_coords, det_properties_by_time, - ) = self._masks2properties(det_masks) + ) = self._masks2properties(det_masks, return_props_by_time=True) det_gt_matching = { t: {_l: _l for _l in det_properties_by_time[t]["labels"]} @@ -684,7 +893,7 @@ def _load(self): det_ts, det_coords, det_properties_by_time, - ) = self._masks2properties(det_masks) + ) = self._masks2properties(det_masks, return_props_by_time=True) # FIXME matching can be slow for big images # raise NotImplementedError("Matching not implemented for 3d version") @@ -703,6 +912,8 @@ def _load(self): self.properties_by_time[_f] = det_properties_by_time self.det_masks[_f] = det_masks + + # Build windows _w = self._build_windows( det_folder, det_masks, @@ -804,10 +1015,32 @@ def _build_windows( logger.debug(f"Built {len(windows)} track windows from {det_folder}.\n") return windows - + + def _apply_transform_and_check(self, img, labels, mask, coords, timepoints, min_time, assoc_matrix): + (img2, mask2, coords2), idx = self.augmenter( + img, mask, coords, timepoints - min_time + ) + if len(idx) > 0: + img, mask, coords = img2, mask2, coords2 + labels = labels[idx] + timepoints = timepoints[idx] + assoc_matrix = assoc_matrix[idx][:, idx] + mask = mask.astype(int) + else: + logger.debug( + "Disable augmentation as no trajectories would be left" + ) + return img, labels, mask, coords, timepoints, assoc_matrix + + @staticmethod + def decompress(data): + if isinstance(data, _CompressedArray): + return data.decompress() + return data + def __getitem__(self, n: int, return_dense=None): # if not set, use default - if self.features == "wrfeat": + if self.features == "wrfeat" or self.features == "pretrained_feats": return self._getitem_wrfeat(n, return_dense) if return_dense is None: @@ -822,12 +1055,15 @@ def __getitem__(self, n: int, return_dense=None): timepoints = track["timepoints"] min_time = track["t1"] - if isinstance(mask, _CompressedArray): - mask = mask.decompress() - if isinstance(img, _CompressedArray): - img = img.decompress() - if isinstance(assoc_matrix, _CompressedArray): - assoc_matrix = assoc_matrix.decompress() + # if isinstance(mask, _CompressedArray): + # mask = mask.decompress() + # if isinstance(img, _CompressedArray): + # img = img.decompress() + # if isinstance(assoc_matrix, _CompressedArray): + # assoc_matrix = assoc_matrix.decompress() + mask = CTCData.decompress(mask) + img = CTCData.decompress(img) + assoc_matrix = CTCData.decompress(assoc_matrix) # cropping if self.cropper is not None: @@ -853,19 +1089,9 @@ def __getitem__(self, n: int, return_dense=None): elif self.features in ("regionprops", "regionprops2"): if self.augmenter is not None: - (img2, mask2, coords2), idx = self.augmenter( - img, mask, coords, timepoints - min_time + img, labels, mask, coords, timepoints, assoc_matrix = self._apply_transform_and_check( + img, labels, 
mask, coords, timepoints, min_time, assoc_matrix ) - if len(idx) > 0: - img, mask, coords = img2, mask2, coords2 - labels = labels[idx] - timepoints = timepoints[idx] - assoc_matrix = assoc_matrix[idx][:, idx] - mask = mask.astype(int) - else: - logger.debug( - "disable augmentation as no trajectories would be left" - ) features = tuple( extract_features_regionprops( @@ -875,21 +1101,11 @@ def __getitem__(self, n: int, return_dense=None): ) features = np.concatenate(features, axis=0) # features = np.zeros((len(coords), self.feat_dim)) - elif self.features == "patch": if self.augmenter is not None: - (img2, mask2, coords2), idx = self.augmenter( - img, mask, coords, timepoints - min_time + img, labels, mask, coords, timepoints, assoc_matrix = self._apply_transform_and_check( + img, labels, mask, coords, timepoints, min_time, assoc_matrix ) - if len(idx) > 0: - img, mask, coords = img2, mask2, coords2 - labels = labels[idx] - timepoints = timepoints[idx] - assoc_matrix = assoc_matrix[idx][:, idx] - mask = mask.astype(int) - else: - print("disable augmentation as no trajectories would be left") - features = tuple( extract_features_patch( m, @@ -902,18 +1118,9 @@ def __getitem__(self, n: int, return_dense=None): features = np.concatenate(features, axis=0) elif self.features == "patch_regionprops": if self.augmenter is not None: - (img2, mask2, coords2), idx = self.augmenter( - img, mask, coords, timepoints - min_time + img, labels, mask, coords, timepoints, assoc_matrix = self._apply_transform_and_check( + img, labels, mask, coords, timepoints, min_time, assoc_matrix ) - if len(idx) > 0: - img, mask, coords = img2, mask2, coords2 - labels = labels[idx] - timepoints = timepoints[idx] - assoc_matrix = assoc_matrix[idx][:, idx] - mask = mask.astype(int) - else: - print("disable augmentation as no trajectories would be left") - features1 = tuple( extract_features_patch( m, @@ -939,6 +1146,25 @@ def __getitem__(self, n: int, return_dense=None): ) features = np.concatenate(features, axis=0) + # MOVED as WRFeat. 
See wrfeat.WRPretrainedFeatures + # elif self.features == "pretrained_feats": + # if self.augmenter is not None: + # img, labels, mask, coords, timepoints, assoc_matrix = self._apply_transform_and_check( + # img, labels, mask, coords, timepoints, min_time, assoc_matrix + # ) + # ts, n_obj = np.unique(timepoints, return_counts=True) + # features = torch.zeros((n_obj.sum(), self.feat_dim)) + # offset = 0 + + # for t, count in zip(ts, n_obj): + # feat = self.pretrained_features[t] # (timepoints -> (n_regions, n_features)) for a SINGLE timepoint + # if feat.shape[0] != count: + # raise ValueError(f"Feature mismatch at time {t}: expected {count}, got {feat.shape[0]}") + # features[offset:offset + count] = feat + # offset += count + + # if features.shape[0] != len(timepoints): + # raise ValueError(f"Pretrained features shape mismatch: {features.shape[0]} != {len(timepoints)}") # remove temporal offset and add timepoints to coords relative_timepoints = timepoints - track["t1"] @@ -957,14 +1183,14 @@ def __getitem__(self, n: int, return_dense=None): ) coords0 = torch.from_numpy(coords).float() - features = torch.from_numpy(features).float() + features = torch.from_numpy(features).float() if isinstance(features, np.ndarray) else features.float() assoc_matrix = torch.from_numpy(assoc_matrix.copy()).float() labels = torch.from_numpy(labels).long() timepoints = torch.from_numpy(timepoints).long() if self.augmenter is not None: coords = coords0.clone() - coords[:, 1:] += torch.randint(0, 256, (1, self.ndim)) + coords[:, 1:] += torch.randint(-1024, 1024, (1, self.ndim)) else: coords = coords0.clone() res = dict( @@ -983,92 +1209,51 @@ def __getitem__(self, n: int, return_dense=None): mask = torch.from_numpy(mask.astype(int)).long() res["mask"] = mask - return res # wrfeat functions... # TODO: refactor this as a subclass or make everything a class factory. 
*very* hacky this way + # -> updated _setup_features_augs_wrfeat to use a factory instead def _setup_features_augs_wrfeat( - self, ndim: int, features: str, augment: int, crop_size: tuple[int] + self ): - # FIXME: hardcoded - feat_dim = 7 if ndim == 2 else 12 - if augment == 1: - augmenter = wrfeat.WRAugmentationPipeline( - [ - wrfeat.WRRandomFlip(p=0.5), - wrfeat.WRRandomAffine( - p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1) - ), - # wrfeat.WRRandomBrightness(p=0.8, factor=(0.5, 2.0)), - # wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)), - ] - ) - elif augment == 2: - augmenter = wrfeat.WRAugmentationPipeline( - [ - wrfeat.WRRandomFlip(p=0.5), - wrfeat.WRRandomAffine( - p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1) - ), - wrfeat.WRRandomBrightness(p=0.8), - wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)), - ] - ) - elif augment == 3: - augmenter = wrfeat.WRAugmentationPipeline( - [ - wrfeat.WRRandomFlip(p=0.5), - wrfeat.WRRandomAffine( - p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1) - ), - wrfeat.WRRandomBrightness(p=0.8), - wrfeat.WRRandomMovement(offset=(-10, 10), p=0.3), - wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)), - ] - ) - else: - augmenter = None + augmenter = wrfeat.AugmentationFactory.create_augmentation_pipeline(self.augment_level) + cropper = wrfeat.AugmentationFactory.create_cropper(self.crop_size, self.ndim) if self.crop_size is not None else None - cropper = ( - wrfeat.WRRandomCrop( - crop_size=crop_size, - ndim=ndim, - ) - if crop_size is not None - else None - ) - return feat_dim, augmenter, cropper + return augmenter, cropper def _load_wrfeat(self): - # Load ground truth - self.gt_masks, self.gt_track_df = self._load_gt() - self.gt_masks = self._check_dimensions(self.gt_masks) - - # Load images - if self.img_folder is None: - if self.gt_masks is not None: - self.imgs = np.zeros_like(self.gt_masks) - else: - raise NotImplementedError("No images and no GT masks") - else: - logger.info("Loading images") - imgs = self._load_tiffs(self.img_folder, dtype=np.float32) - self.imgs = np.stack( - [normalize(_x) for _x in tqdm(imgs, desc="Normalizing", leave=False)] - ) - self.imgs = self._check_dimensions(self.imgs) - if self.compress: - # prepare images to be compressed later (e.g. removing non masked parts for regionprops features) - self.imgs = np.stack( - [ - _compress_img_mask_preproc(im, mask, self.features) - for im, mask in zip(self.imgs, self.gt_masks) - ] - ) - - assert len(self.gt_masks) == len(self.imgs) + # # Load ground truth + # self.gt_masks, self.gt_track_df = self._load_gt() + # self.gt_masks = self._check_dimensions(self.gt_masks) + + # # Load images + # if self.img_folder is None: + # if self.gt_masks is not None: + # self.imgs = np.zeros_like(self.gt_masks) + # else: + # raise NotImplementedError("No images and no GT masks") + # else: + # logger.info("Loading images") + # imgs = self._load_tiffs(self.img_folder, dtype=np.float32) + # self.imgs = np.stack( + # [normalize(_x) for _x in tqdm(imgs, desc="Normalizing", leave=False)] + # ) + # self.imgs = self._check_dimensions(self.imgs) + # if self.compress: + # # prepare images to be compressed later (e.g. 
removing non masked parts for regionprops features) + # self.imgs = np.stack( + # [ + # _compress_img_mask_preproc(im, mask, self.features) + # for im, mask in zip(self.imgs, self.gt_masks) + # ] + # ) + # if np.any(np.isnan(self.imgs)): + # raise ValueError("Compressed images contain NaN values") + + # assert len(self.gt_masks) == len(self.imgs) + imgs = self._prepare_masks_and_imgs(return_orig_imgs=True) # Load each of the detection folders and create data samples with a sliding window windows = [] @@ -1114,13 +1299,42 @@ def _load_wrfeat(self): self.det_masks[_f] = det_masks # build features - - features = joblib.Parallel(n_jobs=8)( - joblib.delayed(wrfeat.WRFeatures.from_mask_img)( - mask=mask[None], img=img[None], t_start=t + if self.features == "pretrained_feats": + self._setup_pretrained_feature_extractor() + if np.all(self.imgs == 0): + raise ValueError("Images are empty. Images must be provided when using pretrained features") + self.feature_extractor.precompute_image_embeddings(imgs) # use NON_NORMALIZED images for pretrained features + # normalization is performed in the feature extractor + features = [ + wrfeat.WRPretrainedFeatures.from_mask_img( + img=img[np.newaxis], + mask=mask[np.newaxis], + feature_extractor=self.feature_extractor, + t_start=t, + additional_properties=self.pretrained_config.additional_features + ) + for t, (mask, img) in enumerate(zip(det_masks, self.imgs)) + ] + for wrf in features: + feats = wrf.features_stacked + if feats is not None and np.any(np.isnan(wrf.features_stacked)): + raise ValueError("NaN in features") + if torch.cuda.is_available(): + self.feature_extractor.embeddings = self.feature_extractor.embeddings.cpu() + torch.cuda.empty_cache() + elif self.features == "wrfeat": + features = joblib.Parallel(n_jobs=8)( + joblib.delayed(wrfeat.WRFeatures.from_mask_img)( + mask=mask[None], img=img[None], t_start=t + ) + for t, (mask, img) in enumerate(zip(det_masks, self.imgs)) ) - for t, (mask, img) in enumerate(zip(det_masks, self.imgs)) - ) + # features = [] + # for t, (mask, img) in enumerate(zip(det_masks, self.imgs)): + # wrf = wrfeat.WRFeatures.from_mask_img( + # mask=mask[None], img=img[None], t_start=t + # ) + # features.append(wrf) properties_by_time = dict() for _t, _feats in enumerate(features): @@ -1228,26 +1442,39 @@ def _getitem_wrfeat(self, n: int, return_dense=None): labels = labels[idx] timepoints = timepoints[idx] assoc_matrix = assoc_matrix[idx][:, idx] - else: - logger.debug("Skipping cropping") + # else: + # logger.debug("Skipping cropping") if self.augmenter is not None: feat = self.augmenter(feat) - + coords0 = np.concatenate((feat.timepoints[:, None], feat.coords), axis=-1) coords0 = torch.from_numpy(coords0).float() assoc_matrix = torch.from_numpy(assoc_matrix.astype(np.float32)) - features = torch.from_numpy(feat.features_stacked).float() + # if self.pca_preprocessor is not None: + # features = self.pca_preprocessor.transform(feat.features_stacked) + # else: + features = feat.features_stacked + if features is not None: + features = torch.from_numpy(features).float() + labels = torch.from_numpy(feat.labels).long() timepoints = torch.from_numpy(feat.timepoints).long() - + + pretrained_features = feat.pretrained_feats + if pretrained_features is not None: + pretrained_features = torch.from_numpy(pretrained_features).float() + if self.max_tokens and len(timepoints) > self.max_tokens: time_incs = np.where(timepoints - np.roll(timepoints, 1))[0] n_elems = time_incs[np.searchsorted(time_incs, self.max_tokens) - 1] timepoints = 
timepoints[:n_elems] labels = labels[:n_elems] coords0 = coords0[:n_elems] - features = features[:n_elems] + if features is not None: + features = features[:n_elems] + if pretrained_features is not None: + pretrained_features = pretrained_features[:n_elems] assoc_matrix = assoc_matrix[:n_elems, :n_elems] logger.debug( f"Clipped window of size {timepoints[n_elems - 1] - timepoints.min()}" @@ -1258,15 +1485,28 @@ def _getitem_wrfeat(self, n: int, return_dense=None): coords[:, 1:] += torch.randint(0, 512, (1, self.ndim)) else: coords = coords0.clone() + + if self.features == "pretrained_feats" and self.rotate_feats: + if isinstance(img, _CompressedArray): + image_shape = img._shape + else: + image_shape = img.shape + # logger.debug(f"Rotating pretrained features with shape {pretrained_features.shape} for image shape {image_shape}") + pretrained_features = CTCData.rotate_features( + pretrained_features, coords, image_shape, + n_rot_dims=self.pretrained_feat_dim, + ) + res = dict( features=features, + pretrained_features=pretrained_features, coords0=coords0, coords=coords, assoc_matrix=assoc_matrix, timepoints=timepoints, labels=labels, ) - + if return_dense: if all([x is not None for x in img]): img = torch.from_numpy(img).float() @@ -1274,7 +1514,809 @@ def _getitem_wrfeat(self, n: int, return_dense=None): mask = torch.from_numpy(mask.astype(int)).long() res["mask"] = mask + + if features is not None: + if torch.any(torch.isnan(features)): + raise ValueError("NaN in features") + elif torch.any(torch.all(features == 0, dim=-1)): + raise ValueError("Empty features") + + if pretrained_features is not None: + if torch.any(torch.isnan(pretrained_features)): + raise ValueError("NaN in pretrained features") + elif torch.any(torch.all(pretrained_features == 0, dim=-1)): + raise ValueError("Empty pretrained features") + + return res + + def _get_pretrained_features_save_path(self): + if self.pretrained_config is not None: + img_folder_name = "_".join(self.root.parts[-3:]) if len(self.root.parts) >= 3 else "_".join(self.root.parts) + img_folder_name = str(img_folder_name).replace(".", "").replace("/", "_").replace("\\", "_").replace(" ", "_") + return self.pretrained_config.save_path / f"embeddings/{img_folder_name}" + + def _setup_pretrained_feature_extractor(self): + if self.ndim == 3: + raise ValueError("Pretrained model feature extraction is not implemented for 3D data") + img_shape = self.imgs.shape[-2:] # initial guess, replaced later if shape changes + from trackastra.data.pretrained_features import ( + FeatureExtractor, + ) + self.feature_extractor_save_path = self._get_pretrained_features_save_path() + # self.feature_extractor = FeatureExtractor.from_model_name( + # self.pretrained_config.model_name, + # img_shape, + # save_path=self.feature_extractor_save_path, + # mode=self.pretrained_config.mode, + # device=self.pretrained_config.device, + # additional_features=self.pretrained_config.additional_features, + # ) + self.feature_extractor = FeatureExtractor.from_config( + self.pretrained_config, + image_shape=img_shape, + save_path=self.feature_extractor_save_path, + ) + self.feature_extractor_input_size = self.feature_extractor.input_size + + def _compute_pretrained_model_features(self): + if self.pretrained_config.model_name is None: + logger.warning("No pretrained model set, feature extraction not run") + return + + try: + self.feature_extractor.input_mul = self._pretrained_model_input_size_factor + except Exception: + logger.warning(f"Cannot change input size for pretrained model: 
{self.pretrained_config.model_name}") + self.pretrained_features = self.feature_extractor.precompute_region_embeddings(self.imgs) + # dict(n_frames) : torch.Tensor(n_regions_in_frame, n_features) + self.feature_extractor = None + + def compute_pretrained_features(self, input_size_factor: int | None = None, model: PretrainedBackboneType = None, mode: PretrainedFeatsExtractionMode = "nearest_patch"): + """Compute pretrained features for the dataset, if the model. input size factor or mode was changed. + + Args: + input_size_factor (int, optional): The input size factor for the pretrained model. Defaults to None. + model (PretrainedBackboneType, optional): The pretrained model to use. Defaults to None. + mode (PretrainedFeatsExtractionMode, optional): The mode to use for feature extraction. Defaults to "nearest_patch". + """ + if input_size_factor is not None: + self._pretrained_model_input_size_factor = input_size_factor + logger.debug(f"Setting input size factor to {input_size_factor}") + if model is not None: + self.pretrained_config.model_name = model + logger.debug(f"Setting pretrained model to {model}") + if mode is not None: + self.pretrained_config.mode = mode + logger.debug(f"Setting feature extraction mode to {mode}") + if input_size_factor is None and model is None and mode is None: + logger.warning("No changes in input size factor, model or mode. Skipping feature extraction.") + return + else: + self._compute_pretrained_model_features() + + @staticmethod + def rotate_features( + features: torch.Tensor, + coords: torch.Tensor, + image_shape: tuple, + n_rot_dims: int | None = None, + skip_first: int = 0, + ) -> torch.Tensor: + """Applies a RoPE-style rotation to each feature vector based on the object's centroid. + + Args: + features: (n_objects, hidden_state_size) tensor of features. + coords: (n_objects, 2) tensor of coordinates. + image_shape: (time, height, width) shape of the input image. + n_rot_dims: Number of feature dimensions to apply rotation to (must be even). If None, rotate all. + skip_first: Number of dimensions to skip at the beginning of the feature vector. No effect if 0. + + Returns: + Rotated features: (n_objects, hidden_state_size) + """ + import math + N, D = features.shape + assert skip_first < n_rot_dims, "skip_first must be less than n_rot_dims." + if n_rot_dims is None: + n_rot_dims = D + if skip_first != 0: + n_rot_dims = n_rot_dims - skip_first + assert n_rot_dims % 2 == 0, "n_rot_dims must be even for RoPE." + assert n_rot_dims <= D, "n_rot_dims cannot exceed feature dimension." 
+ + # Normalize x and y to [0, 1] + H, W = image_shape[-2], image_shape[-1] + x_norm = coords[:, 1] / H + y_norm = coords[:, 2] / W + # Compute two angles for x and y + angle_x = 2 * math.pi * x_norm + angle_y = 2 * math.pi * y_norm + # Interleave angles for each feature pair + angles = torch.stack([angle_x, angle_y], dim=1).repeat(1, n_rot_dims // 2) + angles = angles.view(N, n_rot_dims) + # Prepare cos/sin + cos = torch.cos(angles) + sin = torch.sin(angles) + # Interleave features for rotation + # try: + features_rot = features[:, skip_first:n_rot_dims + skip_first].view(N, -1, 2) + # except Exception: + # breakpoint() + x_feat, y_feat = features_rot[..., 0], features_rot[..., 1] + x_rot = x_feat * cos[:, ::2] - y_feat * sin[:, ::2] + y_rot = x_feat * sin[:, ::2] + y_feat * cos[:, ::2] + rotated_part = torch.stack([x_rot, y_rot], dim=-1).reshape(N, n_rot_dims) + if n_rot_dims < D: + rotated = torch.cat([rotated_part, features[:, n_rot_dims:]], dim=1) + else: + rotated = rotated_part + return rotated # (n_objects, d) + + +class CTCDataAugPretrainedFeats(CTCData): + """CTCData with pretrained features.""" + def __init__( + self, + pretrained_n_augmentations: int = 3, + n_aug_workers: int = 8, + force_recompute=False, + aug_pipeline: PretrainedAugmentations = None, + load_from_disk: bool = False, + *args, + **kwargs + ): + """Args: + root (str): + Folder containing the CTC TRA folder. + ndim (int): + Number of dimensions of the data. Defaults to 2d + (if ndim=3 and data is two dimensional, it will be cast to 3D) + detection_folders: + List of relative paths to folder with detections. + Defaults to ["TRA"], which uses the ground truth detections. + window_size (int): + Window size for transformer. + slice_pct (tuple): + Slice the dataset by percentages (from, to). + augment (int): + if 0, no data augmentation. if > 0, defines level of data augmentation. + features (str): + Types of features to use. + sanity_dist (bool): + Use euclidian distance instead of the association matrix as a target. + crop_size (tuple): + Size of the crops to use for augmentation. If None, no cropping is used. + return_dense (bool): + Return dense masks and images in the data samples. + compress (bool): + Compress elements/remove img if not needed to save memory for large datasets + pretrained_backbone_config (PretrainedFeatureExtractorConfig): + Configuration for the pretrained backbone. + If mode is set to "pretrained_feats", this configuration is used to extract features. + Ignored otherwise. + pretrained_n_augmentations (int): + How many augmented versions of the pretrained model embeddings to create. + n_aug_workers (int): + Number of workers to use for offline augmentation. + load_from_disk (bool): + If True, the offline augmented windows are saved to disk and sampled from there. + If False, all windows are loaded into RAM and sampled from there. + force_recompute (bool): + If False, previously computed offline augmentations are loaded if available. + # pca_preprocessor (EmbeddingsPCACompression): + # PCA preprocessor for the pretrained features. + # If mode is set to "pretrained_feats", this is used to reduce the dimensionality of the features. + # Ignored otherwise. 
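+            Example:
+                Construction sketch; the model name and mode below are illustrative
+                and must be valid AVAILABLE_PRETRAINED_BACKBONES /
+                PretrainedFeatsExtractionMode values::
+
+                    cfg = PretrainedFeatureExtractorConfig(
+                        model_name="facebook/sam2.1-hiera-base-plus",
+                        mode="nearest_patch",
+                    )
+                    # CTCData.__new__ dispatches to this class for pretrained_feats_aug
+                    ds = CTCData(
+                        root="data/CTC_DATA/some_dataset/01",
+                        features="pretrained_feats_aug",
+                        pretrained_backbone_config=cfg,
+                        pretrained_n_augmentations=3,
+                    )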
+ """ + features = kwargs.get("features", None) + + if features is not None and not features == "pretrained_feats_aug": + raise ValueError("This class should only be used with pretrained_feats_aug features") + + self.n_augs = pretrained_n_augmentations + self.n_aug_workers = n_aug_workers + self.force_recompute = force_recompute + self.load_from_disk = load_from_disk + + from trackastra.data.pretrained_augmentations import ( + PretrainedMovementAugmentations, + ) + self.pretrained_feats_augmenter = PretrainedMovementAugmentations(rng_seed=42) if aug_pipeline is None else aug_pipeline + if not isinstance(self.pretrained_feats_augmenter, PretrainedAugmentations): + raise ValueError( + f"Augmentation pipeline must be of type PretrainedAugmentations, got {type(self.pretrained_feats_augmenter)}" + ) + logger.debug(self.pretrained_feats_augmenter) + self.augmented_feature_extractor = None + self.augmented_image_shapes = None # used to store the augmented image shapes, used to rotate features + self.save_windows = True + + self._aug_embeds_file = None # stores the augmented per-object embeddings + self.delete_augs_after_loading = False + # self.window_save_path = None + self._last_selected = None + self._rng = np.random.default_rng() + self._len = None + self._debug = False + + super().__init__(*args, **kwargs, load_immediately=False) + + if self.load_from_disk: + self.window_save_path = self._get_pretrained_features_save_path() / "windows" + self.window_save_path.mkdir(parents=True, exist_ok=True) + self.window_save_path = self.window_save_path / f"{self.config_hash}.zarr" + logger.debug(f"Windows will be saved to {self.window_save_path}") + else: + self.window_save_path = None + logger.debug("Windows will be loaded into RAM") + + if kwargs.get("load_immediately", True): # hook to delay loading if needed + self.start_loading() + start = default_timer() + else: + start = None + + logger.debug("Loading finished, clearing feature extractors...") + + # Clear pre-trained model + self.augmented_feature_extractor = None + # Clear windos as they are loaded from disk when __getitem__ is called + if self.load_from_disk: + self._get_ndim_and_nobj(start, self.windows) + self.windows = None + # Clear intermediate data + if self.delete_augs_after_loading and self._aug_embeds_file.exists(): + try: + self._aug_embeds_file.close() + except Exception as e: + logger.warning(f"Could not close HDF5 file: {e}") + try: + self._aug_embeds_file.unlink() + self._aug_embeds_file = None + except Exception as e: + logger.warning(f"Could not delete file {self._aug_embeds_file}: {e}") + logger.info("Feature extractors cleared.") + else: + self._get_ndim_and_nobj(start, self.windows) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + @property + def config(self): + cfg = super().config + cfg["pretrained_n_augmentations"] = self.n_augs + cfg["pretrained_augmentations"] = self.pretrained_feats_augmenter.get_signature() + return cfg + + @property + def feat_dim(self): + return self.pretrained_config.feat_dim + + def _init_features(self): + self.windows = self._load() + if self.load_from_disk: + self._save_windows() + else: + self._get_ndim_and_nobj(None, self.windows) + + def _setup_features_augs( + self + ): + logger.debug(f"Creating augmentations with level {self.augment_level}") + augmenter = wrfeat.AugmentationFactory.create_augmentation_pipeline(self.augment_level, return_type=wrfeat.WRAugPretrainedFeatures) + cropper = wrfeat.AugmentationFactory.create_cropper(self.crop_size, self.ndim, 
return_type=wrfeat.WRAugPretrainedFeatures) if self.crop_size is not None else None + + return augmenter, cropper + + def _get_ndim_and_nobj(self, start=None, windows=None): + if windows is not None: + self.ndim = windows[0]["coords"][0][0].shape[0] + self.n_objects = tuple(len(t["coords"][0]) for t in windows) + if self.save_windows and self.load_from_disk: + return + if len(self.windows) > 0: + self.ndim = self.windows[0]["coords"][0].shape[1] + self.n_objects = tuple(len(t["coords"][0]) for t in self.windows) + if start is None: + logger.info( + f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track" + f" windows from {self.root}\n" + ) + else: + logger.info( + f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track" + f" windows from {self.root} ({default_timer() - start:.1f}s)\n" + ) + else: + self.n_objects = 0 + logger.warning(f"Could not load any tracks from {self.root}") + + @classmethod + def from_arrays(cls, imgs: np.ndarray, masks: np.ndarray, train_args: dict): + raise NotImplementedError() + # self = cls(**train_args) + # start = default_timer() + + # self.windows = self._load() + # self.n_divs = self._get_ndivs() + + # if len(self.windows) > 0: + # self.ndim = self.windows[0]["coords"][0].shape[1] + # self.n_objects = tuple(len(t["coords"][0]) for t in self.windows) + # logger.info( + # f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track" + # f" windows from {self.root} ({default_timer() - start:.1f}s)\n" + # ) + # else: + # self.n_objects = 0 + # logger.warning(f"Could not load any tracks from {self.root}") + + # if self.compress: + # self._compress_data() + + def _load(self): + all_windows = [] + imgs = self._prepare_masks_and_imgs(return_orig_imgs=True) + + # self.properties_by_time = dict() + self.det_masks = dict() + logger.info("Loading detections") + if len(self.detection_folders) > 1: + raise NotImplementedError("Pretrained aug features with several folders is not supported yet") + + if self._load_windows() is not None: + return self.windows + + for _f in self.detection_folders: + det_folder = self.root / _f + + if det_folder == self.gt_mask_folder: + det_masks = self.gt_masks + logger.info("DET MASK:\tUsing GT masks") + # identity matching + ( + det_labels, + det_ts, + _, + ) = self._masks2properties(det_masks) + + det_gt_matching = { + t: {_l: _l for _l in set(np.unique(d)) - {0}} + for t, d in enumerate(det_masks) + } + else: + det_folder = self._guess_det_folder(root=self.root, suffix=_f) + if det_folder is None: + continue + logger.info(f"DET MASK (guessed):\t{det_folder}") + det_masks = self._load_tiffs(det_folder, dtype=np.int32) + det_masks = self._correct_gt_with_st( + det_folder, det_masks, dtype=np.int32 + ) + det_masks = self._check_dimensions(det_masks) + ( + det_labels, + det_ts, + _, + ) = self._masks2properties(det_masks) + # FIXME matching can be slow for big images + # raise NotImplementedError("Matching not implemented for 3d version") + det_gt_matching = { + t: { + _d: _gt + for _gt, _d in matching( + self.gt_masks[t], + det_masks[t], + threshold=0.3, + max_distance=16, + ) + } + for t in tqdm(range(len(det_masks)), leave=False, desc="Matching") + } + + self.det_masks[_f] = det_masks + # Setup feature extractor + self._setup_pretrained_feature_extractor() + + # Build augmentation pipeline + from trackastra.data.pretrained_features import FeatureExtractorAugWrapper + self.augmented_feature_extractor = FeatureExtractorAugWrapper( + extractor=self.feature_extractor, + 
augmenter=self.pretrained_feats_augmenter, + n_aug=self.n_augs, + force_recompute=self.force_recompute, + ) + self._aug_embeds_file = self.augmented_feature_extractor.get_save_path() + + # Compute features for all augmentations + augmented_dict = self.augmented_feature_extractor.compute_all_features( + images=imgs, + masks=det_masks, + clear_mem=not self.load_from_disk, + n_workers=self.n_aug_workers, + ) + self.augmented_image_shapes = self.augmented_feature_extractor.image_shape_reference + # logger.debug(f"AUG DICT keys : {augmented_dict.keys()}") + + _w = self._build_windows( + det_ts, + det_labels, + det_gt_matching, + augmented_dict + ) + all_windows.extend(_w) + + return all_windows + + def _build_windows(self, ts, labels, matching, augmented_dict): + windows = [] + window_size = self.window_size + n_frames = len(np.unique(ts)) + n_entries = self.n_augs + 1 + # augmented_dict structure : + # - aug_id: + # - metadata: dict, record of the applied augmentations and other metadata + # - data: the data for aug_id + # - t: frame between 0 and n_frames + # - lab: label of the detection + # - coords: coordinates of the detections for (t, lab) + # - features: dict of features for (t, lab) + # - feat_name_1: feature 1 for (t, lab) + # - ... + # - feat_name_n: feature n for (t, lab) + + for t1, t2 in tqdm( + zip(range(0, n_frames), range(window_size, n_frames + 1)), + total=n_frames - window_size + 1, + leave=False, + desc="Building windows", + ): + idx = (ts >= t1) & (ts < t2) + _ts = ts[idx] + _labels = labels[idx] + + _coords = {aug_id: [] for aug_id in range(n_entries)} + _features = {aug_id: {} for aug_id in range(n_entries)} + present_labels_per_aug = [set() for _ in range(n_entries)] + + for aug_id in range(n_entries): + for t in range(t1, t2): + labels_at_t = _labels[_ts == t] + data = augmented_dict[str(aug_id)]["data"] + + coords_at_t = [] + for lab in labels_at_t: + try: + coords_at_t.append(data[t][lab]["coords"]) + present_labels_per_aug[aug_id].add((t, lab)) + except KeyError: + continue + if len(coords_at_t) == 0: + coords_at_t = np.zeros((0, self.ndim), dtype=int) + else: + coords_at_t = np.stack(coords_at_t, axis=0) + _coords[aug_id].extend(coords_at_t) + + features_at_t = [] + for lab in labels_at_t: + try: + features_at_t.append(data[t][lab]["features"]) + except KeyError: + continue + if len(features_at_t) == 0: + features_at_t = {} + else: + features_at_t = [dict(f) for f in features_at_t] + for _f in features_at_t: + for k, v in _f.items(): + if k not in _features[aug_id]: + _features[aug_id][k] = [] + _features[aug_id][k].append(v) + + # --- Filter labels missing in any augmentation --- # + # (This can happen due to downsampling in pretrained features, + # label has too few pixels to have any valid associated features. + # If this occurs for too many labels, check the data and augmentation settings.) + common_labels = set.intersection(*present_labels_per_aug) + keep_mask = np.array([(t, lab) in common_labels for t, lab in zip(_ts, _labels)]) + if np.sum(~keep_mask) > 0: + missing_labels = set( + (t, lab) for t, lab in zip(_ts[~keep_mask], _labels[~keep_mask]) + ) + logger.warning( + f"Labels were removed from window {t1} to {t2}" + f", as those labels are missing in some augmentations. If this occurs for too many labels," + f" check the data and ensure augmentation settings are appropriate." 
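As an aside, the nested `augmented_dict` layout described in the comments inside `_build_windows` can be made concrete with a minimal literal. This is a hypothetical example for illustration only; values and feature names are made up:

```python
import numpy as np

# Keys are augmentation ids as strings; "0" is the unaugmented original.
augmented_dict = {
    "0": {
        "metadata": {"image_shape": (1, 1, 512, 512)},  # record of applied augmentations etc.
        "data": {
            0: {  # frame index t
                1: {  # detection label
                    "coords": np.array([120.0, 240.0]),
                    "features": {"embedding": np.zeros(256, dtype=np.float32)},
                },
            },
        },
    },
}

# Per-object lookup, as done when assembling a window:
coords = augmented_dict["0"]["data"][0][1]["coords"]
print(coords)  # [120. 240.]
```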
+ ) + logger.warning(f"Removed labels: {missing_labels}") + _labels = _labels[keep_mask] + _ts = _ts[keep_mask] + for aug_id in range(n_entries): + filtered_coords = [] + filtered_features = {k: [] for k in _features[aug_id].keys()} + idx_counter = 0 + for t in range(t1, t2): + labels_at_t = _labels[_ts == t] + n = len(labels_at_t) + filtered_coords.append(_coords[aug_id][idx_counter:idx_counter + n]) + for k in _features[aug_id].keys(): + filtered_features[k].extend(_features[aug_id][k][idx_counter:idx_counter + n]) + idx_counter += n + _coords[aug_id] = np.concatenate(filtered_coords, axis=0) if filtered_coords else np.zeros((0, self.ndim), dtype=np.float32) + for k in filtered_features: + _features[aug_id][k] = np.array(filtered_features[k], dtype=np.float32) + else: + # No missing labels, just convert to arrays as usual + for aug_id in range(n_entries): + _coords[aug_id] = np.array(_coords[aug_id], dtype=np.float32) + for k, v in _features[aug_id].items(): + _features[aug_id][k] = np.array(v, dtype=np.float32) + + if len(_labels) == 0: + # raise ValueError(f"No detections in sample {det_folder}:{t1}") # empty frames can happen + A = np.zeros((0, 0), dtype=bool) + else: + A = _ctc_assoc_matrix( + _labels, + _ts, + self.gt_graph, + matching, + ) + + w = dict( + coords=_coords, + t1=t1, + # img=self.imgs[t1:t2], + # mask=det_masks[t1:t2], + assoc_matrix=A, + labels=_labels, + timepoints=_ts, + features=_features, + ) + if not len(_coords) == n_entries or not len(_features) == n_entries: + raise ValueError(f"Number of coords {len(_coords)} or features {len(_features)} does not match number of augmentations {n_entries}") + windows.append(w) + + logger.debug(f"Built {len(windows)} track windows.\n") + return windows + + def _save_windows(self): + if self.window_save_path is not None: + self._len = len(self.windows) + logger.info(f"Saving windows to {self.window_save_path}") + mode = "w" if self.force_recompute else "a" + root = zarr.open_group(str(self.window_save_path), mode=mode) + for i, w in enumerate(self.windows): + group_name = f"window_{i}" + if group_name in root: + del root[group_name] + grp = root.create_group(group_name) + for aug_id in range(self.n_augs + 1): + grp.create_dataset(f"coords_{aug_id}", data=w["coords"][aug_id]) + features_group = grp.create_group(f"features_{aug_id}") + for k, v in w["features"][aug_id].items(): + features_group.create_dataset(k, data=v) + grp.create_dataset("labels", data=w["labels"]) + grp.create_dataset("timepoints", data=w["timepoints"]) + grp.create_dataset("assoc_matrix", data=w["assoc_matrix"]) + grp.attrs["t1"] = w["t1"] + else: + raise ValueError("No augmented embeddings zarr file set. Cannot save windows.") + + def _load_windows(self): + if not self.load_from_disk: + if getattr(self, "windows", None) is not None: + logger.debug("Windows already loaded into memory.") + return self.windows + return None + if self.window_save_path.exists() and not self.force_recompute: + self.windows = [] + logger.info(f"Loading windows from {self.window_save_path}") + root = zarr.open_group(str(self.window_save_path), mode="r") + group_names = sorted( + root.keys(), + key=lambda x: int(x.split("_")[1]) if x.startswith("window_") else x + ) + logger.debug(f"Found {len(group_names)} windows, loading...") + for w in group_names: + grp = root[w] + coords = [grp[f"coords_{aug_id}"][...] for aug_id in range(self.n_augs + 1)] + features = {} + for aug_id in range(self.n_augs + 1): + features[aug_id] = {k: grp[f"features_{aug_id}"][k][...] 
for k in grp[f"features_{aug_id}"].keys()} + labels = grp["labels"][...] + timepoints = grp["timepoints"][...] + assoc_matrix = grp["assoc_matrix"][...] + t1 = grp.attrs["t1"] + self.windows.append(dict( + coords=coords, + features=features, + labels=labels, + timepoints=timepoints, + assoc_matrix=assoc_matrix, + t1=t1, + )) + self._len = len(self.windows) + self._get_ndim_and_nobj(None, self.windows) + logger.info(f"Loaded {self._len} windows from {self.window_save_path}") + return self.windows + + @lru_cache + def _sample_from_memory(self, n: int, aug_choice: int = 0): + """When self.load_from_disk is False, sample a window from memory.""" + # logger.debug(f"Sampling window {n} with augmentation choice {aug_choice}") + track = self.windows[n] + # 0 is original, 1 to n_augs are the augmented versions + coords = track["coords"][aug_choice] + features = track["features"][aug_choice] + assoc_matrix = track["assoc_matrix"] + labels = track["labels"] + timepoints = track["timepoints"] + t1 = track["t1"] + + return coords, features, labels, timepoints, assoc_matrix, t1 + + def _sample_from_file(self, window_id: int, aug_choice: int = 0): + """When self.load_from_disk is True, sample a window from the saved zarr file.""" + root = zarr.open_group(str(self.window_save_path), mode="r") + grp = root[f"window_{window_id}"] + coords = grp[f"coords_{aug_choice}"][...] + features = {} + for k in grp[f"features_{aug_choice}"].keys(): + features[k] = grp[f"features_{aug_choice}"][k][...] + labels = grp["labels"][...] + timepoints = grp["timepoints"][...] + assoc_matrix = grp["assoc_matrix"][...] + t1 = grp.attrs["t1"] + return coords, features, labels, timepoints, assoc_matrix, t1 + + def _augment_item(self, item: wrfeat.WRAugPretrainedFeatures, labels, timepoints, assoc_matrix): + """Apply augmentations to the features.""" + # FIXME some arguments are redundant + if self.cropper is not None: + # Use only if there is at least one timepoint per detection + cropped_item, cropped_idx = self.cropper(item) + cropped_timepoints = item.timepoints[cropped_idx] + if len(np.unique(cropped_timepoints)) == self.window_size: + idx = cropped_idx + item = cropped_item + labels = labels[idx] + timepoints = timepoints[idx] + assoc_matrix = assoc_matrix[idx][:, idx] + # else: + # logger.debug("Skipping cropping") + + if self.augmenter is not None: + item = self.augmenter(item) + + return item, assoc_matrix + + @lru_cache + def get_augmented_image_shape(self, aug_choice: int): + try: + image_shape = self.augmented_image_shapes[aug_choice] + except KeyError: + root = zarr.open_group(str(self._aug_embeds_file), mode="r") + metadata_json = root[str(aug_choice)].attrs["metadata"] + metadata = json.loads(metadata_json) + image_shape = metadata["image_shape"] + return image_shape + + def __len__(self): + if self.save_windows and self.windows is None: + return self._len + else: + return len(self.windows) + + def __getitem__(self, n: int, return_dense=None): + if return_dense is None: + return_dense = self.return_dense + + random_aug_choice = self._rng.integers(0, self.n_augs + 1) + + if self.load_from_disk: + coords, features, labels, timepoints, assoc_matrix, _ = self._sample_from_file( + n, random_aug_choice + ) + else: + coords, features, labels, timepoints, assoc_matrix, _ = self._sample_from_memory( + n, random_aug_choice + ) + + # if return_dense and isinstance(mask, _CompressedArray): + # mask = CTCDataAugPretrainedFeats.decompress(mask) + # if return_dense and isinstance(img, _CompressedArray): + # img = 
CTCDataAugPretrainedFeats.decompress(img) + if isinstance(assoc_matrix, _CompressedArray): + assoc_matrix = CTCDataAugPretrainedFeats.decompress(assoc_matrix) + + coords = np.stack(coords, axis=0) + # features = np.stack(features, axis=0) + + augment_wrfeat = wrfeat.WRAugPretrainedFeatures.from_window( + features=features, + coords=coords, + timepoints=timepoints, + labels=labels, + ) + augmented_data, assoc_matrix = self._augment_item(augment_wrfeat, labels, timepoints, assoc_matrix) + if not isinstance(augmented_data, wrfeat.WRAugPretrainedFeatures): + raise ValueError("Augmented data is not a WRAugPretrainedFeatures. Check that augmenter return type is correct.") + features, pretrained_features, coords, timepoints, labels = augmented_data.to_window() + + shapes = [ + len(labels), + len(timepoints), + len(coords), + len(pretrained_features), + len(assoc_matrix), + ] + if features is not None: + shapes.append(len(features)) + if len(np.unique(shapes)) != 1: + raise ValueError(f"Shape mismatch: {shapes} (labs/timepoints/coords/features)") + + if coords.shape[-1] != self.ndim + 1: + raise ValueError(f"Coords shape mismatch: {coords.shape[-1]} != {self.ndim + 1}") + + # coords is already including time, simply remove min_time along the first axis + # coords[:, 0] -= min_time + + if self.max_tokens and len(timepoints) > self.max_tokens: + time_incs = np.where(timepoints - np.roll(timepoints, 1))[0] + n_elems = time_incs[np.searchsorted(time_incs, self.max_tokens) - 1] + timepoints = timepoints[:n_elems] + labels = labels[:n_elems] + coords = coords[:n_elems] + if features is not None: + features = features[:n_elems] + if pretrained_features is not None: + pretrained_features = pretrained_features[:n_elems] + assoc_matrix = assoc_matrix[:n_elems, :n_elems] + logger.info( + f"Clipped window of size {timepoints[n_elems - 1] - timepoints.min()}" + ) + + coords0 = torch.from_numpy(coords).float() + if features is not None: + features = torch.from_numpy(features).float() + pretrained_features = torch.from_numpy(pretrained_features).float() + assoc_matrix = torch.from_numpy(assoc_matrix.copy()).float() + labels = torch.from_numpy(labels).long() + timepoints = torch.from_numpy(timepoints).long() + + if self.augmenter is not None: + coords = coords0.clone() + coords[:, 1:] += torch.randint(0, 512, (1, self.ndim)) + else: + coords = coords0.clone() + + if self.rotate_feats: + image_shape = self.get_augmented_image_shape(random_aug_choice) + pretrained_features = CTCData.rotate_features( + pretrained_features, coords, image_shape, + n_rot_dims=self.pretrained_feat_dim # // 2 + ) + + res = dict( + features=features, + pretrained_features=pretrained_features, + coords0=coords0, + coords=coords, + assoc_matrix=assoc_matrix, + timepoints=timepoints, + labels=labels, + ) + + # if return_dense: + # if all([x is not None for x in img]): + # img = torch.from_numpy(img).float() + # res["img"] = img + + # mask = torch.from_numpy(mask.astype(int)).long() + # res["mask"] = mask return res @@ -1354,6 +2396,15 @@ def _assoc(A: np.ndarray, labels: np.ndarray, family: np.ndarray): A[i, j] = family[i, labels[j]] +def determine_ctc_class(dataset_kwargs: dict): + if "features" not in dataset_kwargs: + raise ValueError("features must be set in dataset_kwargs") + if dataset_kwargs["features"] == "pretrained_feats_aug": + return CTCDataAugPretrainedFeats + else: + return CTCData + + def _ctc_assoc_matrix(detections, ts, graph, matching): """Create the association matrix for a list of labels and a tracklet parent -> childrend 
graph. @@ -1471,9 +2522,16 @@ def collate_sequence_padding(batch): normal_keys = { "coords": 0, "features": 0, + "pretrained_features": 0, "labels": 0, # Not needed, remove for speed. "timepoints": -1, # There are real timepoints with t=0. -1 for distinction from that. } + actual_keys = { + k: v for k, v in normal_keys.items() if k in batch[0] and batch[0][k] is not None + } + none_keys = [ + k for k in normal_keys.keys() if k in batch[0] and batch[0][k] is None + ] n_pads = tuple(n_max_len - s for s in lens) batch_new = dict( ( @@ -1482,8 +2540,10 @@ def collate_sequence_padding(batch): [pad_tensor(x[k], n_max=n_max_len, value=v) for x in batch], dim=0 ), ) - for k, v in normal_keys.items() + for k, v in actual_keys.items() ) + for k in none_keys: + batch_new[k] = None batch_new["assoc_matrix"] = torch.stack( [ pad_tensor( @@ -1500,6 +2560,8 @@ def collate_sequence_padding(batch): pad_mask[i, n_max_len - n_pad :] = True batch_new["padding_mask"] = pad_mask.bool() + if torch.all(pad_mask.bool()): + raise ValueError("No valid entries for padding mask!") return batch_new diff --git a/trackastra/data/distributed.py b/trackastra/data/distributed.py index 04d0082..c0da0b0 100644 --- a/trackastra/data/distributed.py +++ b/trackastra/data/distributed.py @@ -20,26 +20,18 @@ DistributedSampler, ) -from .data import CTCData +from .data import CTCData, CTCDataAugPretrainedFeats, determine_ctc_class +from .utils import make_hashable logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -def cache_class(cachedir=None): +def cache_class(dataset_kwargs, cachedir=None): """A simple file cache for CTCData.""" - def make_hashable(obj): - if isinstance(obj, tuple | list): - return tuple(make_hashable(e) for e in obj) - elif isinstance(obj, Path): - return obj.as_posix() - elif isinstance(obj, dict): - return tuple(sorted((k, make_hashable(v)) for k, v in obj.items())) - else: - return obj - def hash_args_kwargs(*args, **kwargs): + # FIXME rotate_features arg should not be part of hash, if it changes cached data does not change hashable_args = tuple(make_hashable(arg) for arg in args) hashable_kwargs = make_hashable(kwargs) combined_serialized = json.dumps( @@ -49,7 +41,7 @@ def hash_args_kwargs(*args, **kwargs): return hash_obj.hexdigest() if cachedir is None: - return CTCData + return determine_ctc_class(dataset_kwargs) else: cachedir = Path(cachedir) @@ -60,11 +52,25 @@ def _wrapped(*args, **kwargs): if cache_file.exists(): logger.info(f"Loading cached dataset from {cache_file}") with open(cache_file, "rb") as f: - return pickle.load(f) + c = pickle.load(f) + # if c.pretrained_config is not None: + # cfg = c.pretrained_config + # if cfg.pca_preprocessor_path is not None: + # pca = EmbeddingsPCACompression.from_pretrained_cfg(cfg) + # pca.load_from_file(cfg.pca_preprocessor_path) + # c.pca_preprocessor = pca + + return c else: c = CTCData(*args, **kwargs) + if c.pretrained_config is not None: + c.pretrained_config = c.pretrained_config.to_dict() + c.feature_extractor = None + if isinstance(c, CTCDataAugPretrainedFeats): + c.augmented_feature_extractor = None logger.info(f"Saving cached dataset to {cache_file}") pickle.dump(c, open(cache_file, "wb")) + logger.debug(f"Cache file size: {cache_file.stat().st_size / 1e6:.2f} MB") return c return _wrapped @@ -156,6 +162,7 @@ def sample_batches(self, idx: Iterable[int]): # continue batch = idx_pool[j : j + self.batch_size] batches.append(batch) + return batches def __iter__(self): @@ -212,7 +219,7 @@ def __init__( input_val: list, cachedir: str, 
augment: int, - distributed:bool, + distributed: bool, dataset_kwargs: dict, sampler_kwargs: dict, loader_kwargs: dict, @@ -232,43 +239,83 @@ def prepare_data(self): Running on the main CPU process. """ - CTCData = cache_class(self.cachedir) + CachedData = cache_class( + dataset_kwargs=self.dataset_kwargs, + cachedir=self.cachedir + ) datasets = dict() + for split, inps in zip( ("train", "val"), (self.input_train, self.input_val), ): logger.info(f"Loading {split.upper()} data") start = default_timer() - datasets[split] = torch.utils.data.ConcatDataset( - CTCData( + local_kwargs = deepcopy(self.dataset_kwargs) + if self.dataset_kwargs.get("features") == "pretrained_feats_aug" and split == "val": + # do not compute augmented pretrained features for the val set + local_kwargs["features"] = "pretrained_feats" + + ctc_datasets = [ + CachedData( root=Path(inp), augment=self.augment if split == "train" else 0, - **self.dataset_kwargs, + **local_kwargs, ) for inp in inps + ] + [ + d.feature_extractor_save_path for d in ctc_datasets if split == "train" + ] + datasets[split] = torch.utils.data.ConcatDataset( + ctc_datasets ) + del ctc_datasets logger.info( f"Loaded {len(datasets[split])} {split.upper()} samples (in" f" {(default_timer() - start):.1f} s)\n\n" ) + # if self.dataset_kwargs.get("pretrained_backbone_config") is not None and split == "train": + # cfg = self.dataset_kwargs["pretrained_backbone_config"] + # if cfg.pca_preprocessor_path is not None: + # pca = EmbeddingsPCACompression.from_pretrained_cfg(cfg) + # embeddings_paths = [] + # for p in feature_extractor_save_paths: + # embeddings_paths.append(p) + # pca.fit_on_embeddings(embeddings_paths) + del datasets def setup(self, stage: str): - CTCData = cache_class(self.cachedir) + CachedData = cache_class( + dataset_kwargs=self.dataset_kwargs, + cachedir=self.cachedir + ) self.datasets = dict() + + # if self.dataset_kwargs.get("pretrained_backbone_config") is not None: + # cfg = self.dataset_kwargs["pretrained_backbone_config"] + # if cfg.pca_preprocessor_path is not None: + # pca = EmbeddingsPCACompression.from_pretrained_cfg(cfg) + # pca.load_from_file(cfg.pca_preprocessor_path) + # self.dataset_kwargs["pca_preprocessor"] = pca + for split, inps in zip( ("train", "val"), (self.input_train, self.input_val), ): logger.info(f"Loading {split.upper()} data") start = default_timer() + local_kwargs = deepcopy(self.dataset_kwargs) + if self.dataset_kwargs.get("features") == "pretrained_feats_aug" and split == "val": + # do not computea augmented pretrained features for the val set + local_kwargs["features"] = "pretrained_feats" self.datasets[split] = torch.utils.data.ConcatDataset( - CTCData( + CachedData( root=Path(inp), augment=self.augment if split == "train" else 0, - **self.dataset_kwargs, + **local_kwargs, ) for inp in inps ) @@ -286,7 +333,7 @@ def train_dataloader(self): ) batch_sampler = None else: - sampler=None + sampler = None batch_sampler = BalancedBatchSampler( self.datasets["train"], **self.sampler_kwargs, diff --git a/trackastra/data/features.py b/trackastra/data/features.py index 0c02f78..b4a489e 100644 --- a/trackastra/data/features.py +++ b/trackastra/data/features.py @@ -1,9 +1,13 @@ import itertools +import logging import numpy as np import pandas as pd from skimage.measure import regionprops_table +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + # the property keys that are supported for 2 and 3 dim _PROPERTIES = { diff --git a/trackastra/data/pretrained_augmentations.py 
b/trackastra/data/pretrained_augmentations.py new file mode 100644 index 0000000..f56a0f5 --- /dev/null +++ b/trackastra/data/pretrained_augmentations.py @@ -0,0 +1,418 @@ +from abc import ABC, abstractmethod +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn.functional as F +from torchvision import tv_tensors +from torchvision.transforms import v2 as transforms + +from trackastra.utils.utils import percentile_norm + + +class BaseAugmentation(ABC): + """Base class for windowed region augmentations.""" + def __init__(self, p: float = 0.5, rng_seed=None): + self._p = p + self._rng = np.random.RandomState(rng_seed) + self.applied_record = {} + self.signature = {} + + def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask): + if self._p is None or self._rng.rand() < self._p: + aug = self._get_aug() + return aug(images, masks) + return images, masks + + @abstractmethod + def _get_aug(self) -> transforms.Compose: + raise NotImplementedError() + + +class FlipAugment(BaseAugmentation): + def __init__(self, p_horizontal: float = 0.5, p_vertical: float = 0.5, rng_seed=None): + super().__init__(p=None, rng_seed=rng_seed) + self._p_horizontal = p_horizontal + self._p_vertical = p_vertical + self.signature = { + "FlipAugment": { + "horizontal": self._p_horizontal, + "vertical": self._p_vertical + } + } + + def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask): + if self._rng.rand() < self._p_horizontal: + images = transforms.functional.hflip(images) + masks = transforms.functional.hflip(masks) + self.applied_record["hflip"] = True + else: + self.applied_record["hflip"] = False + if self._rng.rand() < self._p_vertical: + images = transforms.functional.vflip(images) + masks = transforms.functional.vflip(masks) + self.applied_record["vflip"] = True + else: + self.applied_record["vflip"] = False + return images, masks + + def _get_aug(self) -> transforms.Compose: + raise NotImplementedError("Use __call__ instead.") + + +class RotAugment(BaseAugmentation): + + def __init__(self, p: float = 0.5, degrees: int = 15, rng_seed=None): + super().__init__(p, rng_seed=rng_seed) + self.degrees = degrees + self.signature = { + "RotAugment": { + "p": self._p, + "degrees": self.degrees, + } + } + + def _get_aug(self): + self.applied_record["rotation"] = self.degrees + t = transforms.RandomRotation(degrees=self.degrees) + return t + + +class Rot90Augment(BaseAugmentation): + + def __init__(self, p=0.5, rng_seed=None): + super().__init__(p, rng_seed=rng_seed) + self.signature = { + "Rot90Augment": { + "p": self._p, + } + } + + def __call__(self, images, masks): + if self._rng.rand() > self._p: + return images, masks + angle = self._get_aug() + images = transforms.functional.rotate(images, angle, expand=True) + masks = transforms.functional.rotate(masks, angle, expand=True) + return images, masks + + def _get_aug(self): + angle = self._rng.choice([90, 180, 270]) + self.applied_record["rot90"] = int(angle) + return angle + + +class BrightnessJitter(BaseAugmentation): + + def __init__(self, bright_shift: float = 0.5, contrast_shift: float = 0.5, rng_seed=None): + super().__init__(p=None, rng_seed=rng_seed) + self._b_shift = bright_shift + self._c_shift = contrast_shift + self.signature = { + "BrightnessJitter": { + "brightness_shift": self._b_shift, + "contrast_shift": self._c_shift + } + } + + def _get_aug(self): + if self._b_shift is not None: + bright = self._rng.uniform(0, self._b_shift) + else: + bright = None + self.applied_record["brightness_jitter"] = bright 
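The pattern used throughout this module (each augmentation records its sampled parameters in `applied_record`, while `BaseAugmentation.__call__` gates application by probability `p`) makes new transforms easy to add. A minimal sketch of a hypothetical subclass, assuming the `BaseAugmentation` API defined above and inputs normalized to [0, 1]:

```python
from torchvision.transforms import v2 as transforms

class InvertAugment(BaseAugmentation):
    """Hypothetical example: randomly invert image intensities."""

    def __init__(self, p: float = 0.5, rng_seed=None):
        super().__init__(p=p, rng_seed=rng_seed)
        self.signature = {"InvertAugment": {"p": self._p}}

    def _get_aug(self):
        # Record that the transform fired so gather_records() can report it.
        self.applied_record["invert"] = True
        return transforms.Lambda(lambda x: 1.0 - x)

    def __call__(self, images, masks):
        if self._rng.rand() < self._p:
            images = self._get_aug()(images)
        return images, masks
```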
+ if self._c_shift is not None: + contrast = self._rng.uniform(0, self._c_shift) + else: + contrast = None + self.applied_record["contrast_jitter"] = contrast + return transforms.ColorJitter(brightness=bright, contrast=contrast) + + +class AddGaussianNoise(BaseAugmentation): + def __init__(self, mean: float = 0.0, std: float = 0.1, rng_seed=None): + super().__init__(p=None, rng_seed=rng_seed) + self.mean = mean + self.sigma = std + self.signature = { + "AddGaussianNoise": { + "mean": self.mean, + "std": self.sigma + } + } + + def _get_aug(self): + # sample random mean/std + self.applied_record["gaussian_noise"] = (self.mean, self.sigma) + return transforms.GaussianNoise(mean=self.mean, sigma=self.sigma) + + def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask): + aug = self._get_aug() + images = aug(images) + return images, masks + + +class GaussianBlur(BaseAugmentation): + def __init__(self, kernel_size: int = 3, sigma: tuple[float] = (0.01, 1.0), rng_seed=None): + super().__init__(p=None, rng_seed=rng_seed) + self.kernel_size = kernel_size + self.sigma = sigma + self.signature = { + "GaussianBlur": { + "kernel_size": self.kernel_size, + "sigma": self.sigma + } + } + + def _get_aug(self): + self.applied_record["gaussian_blur"] = (self.kernel_size, self.sigma) + return transforms.GaussianBlur(kernel_size=self.kernel_size, sigma=self.sigma) + + def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask): + aug = self._get_aug() + images = aug(images) + return images, masks + + +class RandomAffine(BaseAugmentation): + def __init__(self, degrees: float = 0.0, translate: tuple[float, float] = (0.0, 0.0), scale: tuple[float, float] = (1.0, 1.0), rng_seed=None): + super().__init__(p=None, rng_seed=rng_seed) + self.degrees = degrees + self.translate = translate + self.scale = scale + self.signature = { + "RandomAffine": { + "degrees": self.degrees, + "translate": self.translate, + "scale": self.scale + } + } + + def _get_aug(self): + return transforms.RandomAffine(degrees=self.degrees, translate=self.translate, scale=self.scale) + + def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask): + aug = self._get_aug() + images, masks = aug(images, masks) + return images, masks + + +class ElasticTransform(BaseAugmentation): + def __init__(self, p=0.5, alpha: float = 10.0, sigma: float = 0.5, rng_seed=None): + super().__init__(p=p, rng_seed=rng_seed) + self.alpha = alpha + self.sigma = sigma + self.signature = { + "ElasticTransform": { + "p": self._p, + "alpha": self.alpha, + "sigma": self.sigma + } + } + + def _get_aug(self): + alpha = self._rng.uniform(0, self.alpha) + sigma = self._rng.uniform(0, self.sigma) + self.applied_record["elastic_transform"] = (alpha, sigma) + return transforms.ElasticTransform(alpha=alpha, sigma=sigma) + + +class RandomScale(BaseAugmentation): + def __init__(self, p: float = 0.9, max_scale: float = 1.0, min_scale=0.8, preserve_size=False, rng_seed=None): + super().__init__(p=p, rng_seed=rng_seed) + self.min_scale = min_scale + self.max_scale = max_scale + self.preserve_size = preserve_size + self.signature = { + "RandomScale": { + "p": self._p, + "min_scale": self.min_scale, + "max_scale": self.max_scale, + "preserve_size": self.preserve_size + } + } + + def _get_aug(self): + scale = self._rng.uniform(self.min_scale, self.max_scale) + self.applied_record["random_scale"] = scale + return scale + + def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask): + if self._p is None or self._rng.rand() < self._p: + scale = self._get_aug() + orig_h, 
orig_w = images.shape[-2], images.shape[-1] + new_h, new_w = int(orig_h * scale), int(orig_w * scale) + + # Resize images and masks + images_scaled = F.interpolate(images, size=(new_h, new_w), mode="bilinear", align_corners=False) + masks_scaled = F.interpolate(masks.float(), size=(new_h, new_w), mode="nearest").long() + + if self.preserve_size: + pad_h = max(orig_h - new_h, 0) + pad_w = max(orig_w - new_w, 0) + pad = [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] # left, right, top, bottom + + images_scaled = F.pad(images_scaled, pad, mode="constant", value=0) + masks_scaled = F.pad(masks_scaled, pad, mode="constant", value=0) + + # If scaled image is larger, crop to original size + images_scaled = images_scaled[..., :orig_h, :orig_w] + masks_scaled = masks_scaled[..., :orig_h, :orig_w] + + return images_scaled, masks_scaled + return images, masks + + +class IdentityAugment(BaseAugmentation): + """Identity augmentation for debugging purposes.""" + def __init__(self, p: float = 1.0, rng_seed=None): + super().__init__(p=p, rng_seed=rng_seed) + self.signature = { + "IdentityAugment": { + "p": self._p + } + } + + def _get_aug(self): + self.applied_record["identity"] = True + return transforms.Lambda(lambda x: x) + + def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask): + return images, masks + + +class PretrainedAugmentations: + """Augmentation pipeline to get augmented copies of model embeddings.""" + default_normalize = percentile_norm + + def __init__(self, rng_seed=None, normalize=True, shuffle=True): + self.aug_record = {} + self.aug_list = [ + # IdentityAugment(rng_seed=rng_seed), # debugging + BrightnessJitter(bright_shift=0.3, contrast_shift=0.3, rng_seed=rng_seed), + FlipAugment(p_horizontal=0.5, p_vertical=0.5, rng_seed=rng_seed), + # RotAugment(degrees=10, rng_seed=rng_seed), + Rot90Augment(p=0.5, rng_seed=rng_seed), + # AddGaussianNoise(mean=0.0, std=0.02, rng_seed=rng_seed), + RandomScale(rng_seed=rng_seed), + GaussianBlur(kernel_size=5, sigma=(0.01, 2.0), rng_seed=rng_seed), + # ElasticTransform(p=0.25, alpha=10.0, sigma=0.5, rng_seed=rng_seed), + # RandomAffine(degrees=0.0, translate=(0.1, 0.1), scale=(0.9, 1.1), rng_seed=rng_seed), + ] + self._aug = None + self._rng = np.random.RandomState(rng_seed) + self.normalize = normalize + self.image_shape = None + self.shuffle = shuffle + + def __call__(self, images: torch.Tensor, masks: tv_tensors.Mask, normalize_func=None) -> tuple[torch.Tensor, tv_tensors.Mask, dict]: + """Applies the augmentations to the images.""" + images, masks = self.preprocess(images, masks, normalize_func=normalize_func) + + if self.shuffle: + aug_list = self.aug_list.copy() + self._rng.shuffle(aug_list) + else: + aug_list = self.aug_list + self._aug = transforms.Compose(aug_list) + + images = torch.unsqueeze(images, dim=1) # add channel dimension (T, C, H, W) for augmentation + masks = torch.unsqueeze(masks, dim=1) # add channel dimension (T, C, H, W) for augmentation + + images, masks = self._aug(images, masks) + if torch.isnan(images).any() or torch.isnan(masks).any(): + raise RuntimeError("NaN values found in images or masks after augmentation.") + self.image_shape = images.shape + # NOTE : most models do require 3 channels, but this will be done in FeatureExtractor, so the output is squeezed + return images.squeeze(), masks.squeeze(), self.gather_records() + + def get_signature(self): + """Returns the signature of the augmentations.""" + if self._aug is None: + self._aug = transforms.Compose(self.aug_list) + signatures = 
OrderedDict()
+        for aug in self.aug_list:
+            if aug.signature:
+                signatures.update(aug.signature)
+        return signatures
+
+    def __add__(self, other):
+        """Combines two augmentation pipelines."""
+        if not isinstance(other, PretrainedAugmentations):
+            raise TypeError("Can only combine with another PretrainedAugmentations instance.")
+        # NOTE: the seed of a RandomState cannot be recovered afterwards, so the
+        # combined pipeline gets a fresh RNG.
+        combined = PretrainedAugmentations(rng_seed=None, normalize=self.normalize, shuffle=self.shuffle)
+        combined.aug_list = self.aug_list + other.aug_list
+        return combined
+
+    def __repr__(self):
+        sig = self.get_signature()
+        msg = "Augmentation pipeline:\n"
+        for aug_name, params in sig.items():
+            msg += f"- {aug_name}: {params}\n"
+        msg += "_" * 40 + "\n"
+        return msg.strip()
+
+    def preprocess(self, images, masks, normalize_func=None):
+        # Convert first, so that plain arrays/lists are handled before the shape checks
+        if not isinstance(images, torch.Tensor):
+            try:
+                images = torch.tensor(images, dtype=torch.float32)
+            except Exception as e:
+                raise ValueError(f"Failed to convert images to tensor: {e}")
+        if not isinstance(masks, tv_tensors.Mask):
+            try:
+                masks = tv_tensors.Mask(masks, dtype=torch.int64)
+            except Exception as e:
+                raise ValueError(f"Failed to convert masks to tensor: {e}")
+
+        if len(images.shape) != 3:
+            raise ValueError(f"Images must be tensor of shape (T, H, W), got {len(images.shape)}D tensor.")
+        if len(masks.shape) != 3:
+            raise ValueError(f"Masks must be tensor of shape (T, H, W), got {len(masks.shape)}D tensor.")
+
+        if normalize_func is not None:
+            if not callable(normalize_func):
+                raise ValueError("normalize_func must be a callable function.")
+            images = normalize_func(images)
+
+        return images, masks
+
+    def gather_records(self):
+        """Gathers the applied augmentation records."""
+        self.aug_record = {}
+        for aug in self.aug_list:
+            self.aug_record.update(aug.applied_record)
+        self.aug_record["image_shape"] = self.image_shape
+        return self.aug_record
+
+
+class PretrainedMovementAugmentations(PretrainedAugmentations):
+    """Augmentation pipeline for movement embeddings."""
+    def __init__(self, rng_seed=None, normalize=True, shuffle=True):
+        super().__init__(rng_seed=rng_seed, normalize=normalize, shuffle=shuffle)
+        self.aug_list = [
+            FlipAugment(p_horizontal=0.5, p_vertical=0.5, rng_seed=rng_seed),
+            Rot90Augment(p=0.5, rng_seed=rng_seed),
+            RandomScale(rng_seed=rng_seed),
+        ]
+
+
+class PretrainedIntensityAugmentations(PretrainedAugmentations):
+    """Augmentation pipeline for intensity embeddings."""
+    def __init__(self, rng_seed=None, normalize=True, shuffle=True):
+        super().__init__(rng_seed=rng_seed, normalize=normalize, shuffle=shuffle)
+        self.aug_list = [
+            BrightnessJitter(bright_shift=0.05, contrast_shift=0.05, rng_seed=rng_seed),
+            # FlipAugment(p_horizontal=0.5, p_vertical=0.5, rng_seed=rng_seed),
+            # Rot90Augment(p=0.5, rng_seed=rng_seed),
+            AddGaussianNoise(mean=0.0, std=0.02, rng_seed=rng_seed),
+            # RandomScale(rng_seed=rng_seed),
+        ]
+
+
+class IdentityAugmentations(PretrainedAugmentations):
+    """Identity augmentation pipeline for debugging."""
+    def __init__(self, rng_seed=None, normalize=True, shuffle=True):
+        super().__init__(rng_seed=rng_seed, normalize=normalize, shuffle=shuffle)
+        self.aug_list = [
+            IdentityAugment(p=1.0, rng_seed=rng_seed),
+        ]
\ No newline at end of file
diff --git a/trackastra/data/pretrained_features.py b/trackastra/data/pretrained_features.py
new file mode 100644
index 0000000..0165967
--- /dev/null
+++ b/trackastra/data/pretrained_features.py
@@ -0,0 +1,2122 @@
+import json
+import logging
+import os
+from abc import ABC, abstractmethod
+from dataclasses
import asdict, dataclass +from functools import partial +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +import joblib +import numpy as np +import torch +import torch.nn.functional as F +import zarr +from numcodecs import Blosc +from sam2.sam2_image_predictor import SAM2ImagePredictor +from skimage.measure import regionprops +from tqdm import tqdm +from transformers import ( + AutoImageProcessor, + # Dinov2Config, + # Dinov2Model, + AutoModel, + HieraConfig, + HieraModel, + SamModel, + SamProcessor, +) + +from trackastra.data import wrfeat +from trackastra.utils.utils import percentile_norm + +if TYPE_CHECKING: + from trackastra.data.pretrained_augmentations import PretrainedAugmentations + +try: + from micro_sam.util import get_sam_model as get_microsam_model + MICRO_SAM_AVAILABLE = True +except ImportError: + MICRO_SAM_AVAILABLE = False + +try: + from tarrow.models import TimeArrowNet + from tarrow.utils import normalize as tap_normalize + TARROW_AVAILABLE = True +except ImportError: + TARROW_AVAILABLE = False + +try: + from cellpose import transforms as cp_transforms + from cellpose.vit_sam import Transformer as CellposeSAM + CELLPOSE_AVAILABLE = True +except ImportError: + CELLPOSE_AVAILABLE = False + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +# Updated with actual class after each definition +# See register_backbone decorator +AVAILABLE_PRETRAINED_BACKBONES = {} + +PretrainedFeatsExtractionMode = Literal[ + # "exact_patch", # Uses the image patch centered on the detection for embedding + "nearest_patch", # Runs on whole image, then finds the nearest patch to the detection in the embedding + "mean_patches_bbox", # Runs on whole image, then averages the embeddings of all patches that intersect with the detection's bounding box + "mean_patches_exact", # Runs on whole image, then averages the embeddings of all patches that intersect with the detection + "max_patches_bbox", # Runs on whole image, then takes the maximum for each feature dimension of all patches that intersect with the detection + "max_patches_exact", # Runs on whole image, then takes the maximum for each feature dimension of all patches that intersect with the detection + "median_patches_exact", # Runs on whole image, then takes the median for each feature dimension of all patches that intersect with the detection +] + +PretrainedBackboneType = Literal[ # cannot unpack this directly in python < 3.11 so it has to be copied + "facebook/hiera-tiny-224-hf", # 768 + "facebook/dinov2-base", # 768 + "facebook/sam-vit-base", # 256 + "facebook/sam2-hiera-large", # 256 + "facebook/sam2.1-hiera-base-plus", # 256 + "facebookresearch/co-tracker", # 128 + "microsam/vit_b_lm", + "microsam/vit_l_lm", + "weigertlab/tarrow", # arbitrary. 
default 32 + "mouseland/cellpose-sam", # 192 + "facebook/sam2.1-hiera-base-plus/highres", + "debug/random", + "debug/encoded_labels", # 32 +] + + +def register_backbone(model_name, feat_dim): + def decorator(cls): + AVAILABLE_PRETRAINED_BACKBONES[model_name] = { + "class": cls, + "feat_dim": feat_dim, + } + return cls + return decorator + + +# Feature extraction from pretrained models +# Meant to wrap any transformers model +# >NOTE : currently not applicable to 3D data +# (but aggregation-based modes may be adapted eventually) +import time +from functools import wraps + + +def average_time_decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if not hasattr(wrapper, 'total_time'): + wrapper.total_time = 0 + wrapper.call_count = 0 + + start_time = time.time() + result = func(*args, **kwargs) + elapsed_time = time.time() - start_time + + wrapper.total_time += elapsed_time + wrapper.call_count += 1 + average_time = wrapper.total_time / wrapper.call_count + + print(f"Average time taken by {func.__name__}: {average_time:.6f} seconds over {wrapper.call_count} calls") + + return result + + return wrapper + + +# Configs for pretrained models ### + +@dataclass +class PretrainedFeatureExtractorConfig: + """model_name (str): + Specify the pretrained backbone to use. + model_path (str | Path): + Path to the pretrained model. + save_path (str | Path): + Specify the path to save the embeddings. + batch_size (int): + Specify the batch size to use for the model. + mode (str): + Specify the mode to use for the model. + Currently available modes are "nearest_patch", "mean_patches_bbox", "mean_patches_exact", "max_patches_bbox", "max_patches_exact". + normalize_embeddings (bool): + Whether to normalize the embeddings (divide by the norm). + device (str): + Specify the device to use for the model. + If not set and "pretrained_feats" is used, the device is automatically set by default to "cuda", "mps" or "cpu" as available. + n_augmented_copies (int): + How many augmented copies of the embeddings to create. If 0, only the original embeddings are saved. Creates n+1 embeddings entries total. + additional_features (str): + Specify any additional features (from regionprops) to include in the extraction process. See WRFeat documentation for available features. Unused if None. + pca_components (int): + Specify the number of PCA components to use for dimensionality reduction of the features. Unused if None. + pca_preprocessor_path (str | Path): + Specify the path to the pickled PCA preprocessor. This is used to transform the features to a reduced PCA feature space. 
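For orientation, constructing one of these configs for a registered backbone might look like the following. This is a hedged sketch runnable only in the context of this module; the chosen backbone, path, and mode are placeholders:

```python
from pathlib import Path

cfg = PretrainedFeatureExtractorConfig(
    model_name="facebook/dinov2-base",  # one of the registered backbones (feat_dim 768)
    save_path=Path("embeddings/"),      # where the embeddings are written
    batch_size=4,
    mode="mean_patches_exact",          # aggregate patch embeddings inside each mask
    normalize_embeddings=True,
)
# feat_dim is resolved in __post_init__ from AVAILABLE_PRETRAINED_BACKBONES,
# and device is auto-guessed (cuda/mps/cpu) when left as None.
print(cfg.feat_dim)
```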
+ """ + model_name: PretrainedBackboneType + model_path: str | Path = None + save_path: str | Path = None + batch_size: int = 4 + mode: PretrainedFeatsExtractionMode = "nearest_patch" + normalize_embeddings: bool = True # whether to normalize the embeddings (divide by the norm) + device: str | None = None + feat_dim: int = None + additional_features: str | None = None # for regionprops features + additional_feat_dim: int = 0 # for regionprops features + n_augmented_copies: int = 0 # number of augmented copies to create + seed: int | None = None # seed for debug/random + # pca_components: int = None # for PCA reduction of the features + # pca_preprocessor_path: str | Path = None # for PCA preprocessor path + # apply_rope: bool = False # whether to apply RoPE-like rotation to the features based on coordinates + + def __post_init__(self): + self._guess_device() + self.model_path = self._check_path(self.model_path) + self.save_path = self._check_path(self.save_path) + # self.pca_preprocessor_path = self._check_path(self.pca_preprocessor_path) + self._check_model_availability() + + def _check_path(self, path): + if path is not None and not isinstance(path, str | Path): + raise ValueError(f"Path must be a string or Path object, got {type(path)}.") + if isinstance(path, str): + return Path(path).resolve() + return path + + def _check_model_availability(self): + if self.model_name not in AVAILABLE_PRETRAINED_BACKBONES.keys(): + raise ValueError(f"Model {self.model_name} is not available for feature extraction.") + if self.model_name == "weigertlab/tarrow": + if not TARROW_AVAILABLE: + raise ImportError("TArrow is not available. Please install it to use this model.") + elif self.model_path is None: + raise ValueError("Model path must be specified for TArrow.") + _, self.feat_dim = TAPFeatures._load_model_from_path(self.model_path) + else: + self.feat_dim = AVAILABLE_PRETRAINED_BACKBONES[self.model_name]["feat_dim"] + if self.additional_features is not None: + # TODO if this ever accepts 3D data this will be incorrect + self.additional_feat_dim = wrfeat.WRFeatures.PROPERTIES_DIMS[ + self.additional_features + ][2] + if self.additional_features not in wrfeat._PROPERTIES: + raise ValueError(f"Additional feature {self.additional_features} is not valid.") + # if self.pca_components is not None: + # self.feat_dim = self.pca_components + + def _guess_device(self): + if self.device is None: + should_use_mps = ( + torch.backends.mps.is_available() + and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None + and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0" + ) + self.device = ( + "cuda" + if torch.cuda.is_available() + else ( + "mps" + if should_use_mps and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") + else "cpu" + ) + ) + + try: + torch.device(self.device) # check if device is valid + except Exception as e: + raise ValueError(f"Invalid device: {self.device}") from e + + def to_dict(self): + return asdict(self) + + @classmethod + def from_dict(cls, config_dict): + return cls(**config_dict) + +# Feature extractors ### + + +class FeatureExtractor(ABC): + model_name = None + _available_backbones = None + + def __init__( + self, + image_size: tuple[int, int], + save_path: str | Path, + batch_size: int = 4, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + mode: PretrainedFeatsExtractionMode = "nearest_patch", + normalize_embeddings: bool = True, + **kwargs, + ): + """ + Initializes a pretrained model feature extractor with the given parameters. 
+ Consumes images, computes embeddings of shape T, H*W, N (N the model feature dimension), + and generates n_regions x N embeddings for each object in each frame using the specified mode. + + Args: + - image_size (tuple[int, int]): Size of the input images (height, width). + - save_path (str | Path): Path to save the embeddings. + - batch_size (int): Batch size to use for the model. + - device (str): Device to use for the model. Defaults to "cuda" if available, otherwise "cpu". + - mode (str): Mode to use for the model. Defaults to "nearest_patch". See type PretrainedFeatsExtractionMode for available modes. + - normalize_embeddings (bool): Whether to normalize the embeddings (divide by the norm). Defaults to True. For aggregation-based modes, this is applied before aggregation. + """ + # Image processor extra args + # Modify as needed in subclasses + self.im_proc_kwargs = { + "do_rescale": False, + "do_normalize": False, + "do_resize": True, + "return_tensors": "pt", + "do_center_crop": False, + } + # Model specs + self.model = None + self._input_size: tuple[int] = None + self._final_grid_size: tuple[int] = None + self.n_channels: int = None + self.hidden_state_size: int = None + self.model_patch_size: int = None + # Data specs + self.orig_image_size = image_size + self.orig_n_channels = 1 + # Batch options and preprocessing + self.do_normalize = True + self.rescale_batches = False + self.channel_first = True + self.batch_return_type: Literal["list[np.ndarray]", "np.ndarray", "torch.Tensor"] = "np.ndarray" + self.batch_size = batch_size + self.device = device + # Parameters for embedding extraction + self.mode = mode + # If running FeatureExtractor in a parallelized context, set to False to avoid overhead + # from spawning many threads within the parallelized context. + self.parallel = True + self.additional_features = None + self.apply_rope = False # deprecated, use "rotate_features" in CTCData # TODO remove + self.normalize_embeddings = normalize_embeddings + # Saving parameters + self.save_path: str | Path = save_path + self.do_save = True + self.force_recompute = False + + self.embeddings = None + self._debug_view = False + # self._debug = True + + if not isinstance(self.save_path, Path): + self.save_path = Path(self.save_path) + + if not self.save_path.exists(): + self.save_path.mkdir(parents=True, exist_ok=True) + + @property + def input_size(self): + return self._input_size + + @input_size.setter + def input_size(self, value: int | tuple[int]): + if isinstance(value, int): + value = (value, value) + elif isinstance(value, tuple): + if len(value) != 2: + raise ValueError("Input size must be a tuple of length 2.") + else: + raise ValueError("Input size must be an int or a tuple of ints.") + self._input_size = value + self._set_model_patch_size() + + @property + def final_grid_size(self): + return self._final_grid_size + + @final_grid_size.setter + def final_grid_size(self, value: int | tuple[int]): + if isinstance(value, int): + value = (value, value) + elif isinstance(value, tuple): + if len(value) != 2: + raise ValueError("Final grid size must be a tuple of length 2.") + else: + raise ValueError("Final grid size must be an int or a tuple of ints.") + self._final_grid_size = value + self._set_model_patch_size() + + @property + def model_name_path(self): + return self.model_name.replace("/", "-") + + @staticmethod + def _load_model_from_path(self) -> tuple[torch.nn.Module, int]: + """Loads the model from the specified path. Returns the model and the model feature dimension (e.g. 
256 for SAM2).""" + raise NotImplementedError("This model currently only supports being loaded from huggingface's hub.") + + @classmethod + def from_model_name(cls, + model_name: PretrainedBackboneType, + image_shape: tuple[int, int], + save_path: str | Path, + device: torch.device = "cuda" if torch.cuda.is_available() else "cpu", + mode="nearest_patch", + additional_features=None, + model_folder=None, + ): + cls._available_backbones = AVAILABLE_PRETRAINED_BACKBONES + if model_name not in cls._available_backbones: + raise ValueError(f"Model {model_name} is not available for feature extraction.") + logger.info(f"Using model {model_name} with mode {mode} for pretrained feature extraction.") + backbone = cls._available_backbones[model_name]["class"] + backbone.model_name = model_name + model = backbone( + image_size=image_shape, + save_path=save_path, + device=device, + mode=mode, + model_folder=model_folder, + ) + model.additional_features = additional_features + return model + + @classmethod + def from_config(cls, config: PretrainedFeatureExtractorConfig, image_shape: tuple[int, int], save_path: str | Path | None = None): + cls._available_backbones = AVAILABLE_PRETRAINED_BACKBONES + if config.model_name not in cls._available_backbones: + raise ValueError(f"Model {config.model_name} is not available for feature extraction.") + logger.info(f"Using model {config.model_name} with mode {config.mode} for pretrained feature extraction.") + backbone = cls._available_backbones[config.model_name]["class"] + + parts = config.model_name.split("/") + if len(parts) > 2: + model_name = "/".join(parts[:2]) + else: + model_name = config.model_name + backbone.model_name = model_name + + model = backbone( + image_size=image_shape, + save_path=save_path if save_path is not None else config.save_path, + batch_size=config.batch_size, + device=config.device, + mode=config.mode, + additional_features=config.additional_features, + model_folder=config.model_path, + normalize_embeddings=config.normalize_embeddings, + seed=config.seed if hasattr(config, "seed") else None, + # n_augmented_copies=config.n_augmented_copies, + # aug_pipeline=PretrainedAugmentations() if config.n_augmented_copies > 0 else None, + ) + model.additional_features = config.additional_features + model.normalize_embeddings = config.normalize_embeddings + # model.apply_rope = config.apply_rope + return model + + def clear_model(self): + """Clears the model from memory.""" + if self.model is not None: + del self.model + self.model = None + torch.cuda.empty_cache() + logger.info("Model cleared from memory.") + else: + logger.warning("No model to clear from memory.") + + def _set_model_patch_size(self): + if self.final_grid_size is None or self.input_size is None: + self.model_patch_size = None + else: + if not isinstance(self.input_size, tuple): + raise ValueError("Input size must be a tuple of ints.") + self.model_patch_size = ( + self.input_size[0] // self.final_grid_size[0], + self.input_size[1] // self.final_grid_size[1], + ) + if self.model_patch_size[0] <= 0 or self.model_patch_size[1] <= 0: + raise ValueError("Model patch size must be greater than 0.") + + def compute_region_features( + self, + coords, + masks=None, + timepoints=None, + labels=None, + # embeddings=None + ) -> torch.Tensor: + feats = torch.zeros(len(coords), self.hidden_state_size, device=self.device) + match self.mode: + case "nearest_patch": + feats = self._nearest_patches(coords, masks, norm=self.normalize_embeddings) + return feats # Return early, nothing else to do + case 
mode if mode.endswith("_patches_exact"): + if masks is None or labels is None or timepoints is None: + raise ValueError("Masks and labels must be provided for the chosen patch mode.") + feats_func = partial(self._agg_patches_exact, masks, timepoints, labels, norm=self.normalize_embeddings) + case mode if mode.endswith("_patches_bbox"): + if masks is None or labels is None or timepoints is None: + raise ValueError("Masks and labels must be provided for the chosen patch mode.") + feats_func = partial(self._agg_patches_bbox, masks, timepoints, labels, norm=self.normalize_embeddings) + case _: + raise NotImplementedError(f"Mode {self.mode} is not implemented.") + + # Only for aggregation modes + if "max" in self.mode: + feats = feats_func(agg=torch.max) + elif "mean" in self.mode: + feats = feats_func(agg=torch.mean) + elif "median" in self.mode: + feats = feats_func(agg=torch.median) + else: + raise NotImplementedError(f"Unknown aggregation for mode {self.mode}") + + assert feats.shape == (len(coords), self.hidden_state_size) + return feats # (n_regions, embedding_size) + + def precompute_image_embeddings(self, images, **kwargs): # , windows, window_size): + """Precomputes embeddings for all images.""" + missing = self._check_missing_embeddings() + all_embeddings = torch.zeros(len(images), self.final_grid_size[0] * self.final_grid_size[1], self.hidden_state_size, device=self.device) + if missing: + for ts, batches in tqdm(self._prepare_batches(images), total=len(images) // self.batch_size, desc="Computing embeddings", leave=False): + embeddings = self._run_model(batches, **kwargs) + if torch.any(embeddings.isnan()): + raise RuntimeError("NaN values found in features.") + # logger.debug(f"Embeddings shape: {embeddings.shape}") + all_embeddings[ts] = embeddings.to(torch.float32) + assert embeddings.shape[-1] == self.hidden_state_size + self.embeddings = all_embeddings + self._save_features(all_embeddings) + # logger.debug(f"Precomputed embeddings shape: {self.embeddings.shape}") + return self.embeddings + + def _extract_region_embeddings(self, all_frames_embeddings, window, start_index, remaining=None): + window_coords = window["coords"] + window_timepoints = window["timepoints"] + window_masks = window["mask"] + window_labels = window["labels"] + + n_regions_per_frame, features = self.extract_embedding(window_masks, window_timepoints, window_labels, window_coords) + + for i in range(remaining or len(n_regions_per_frame)): + # if computing remaining frames' embeddings, start from the end + obj_per_frame = n_regions_per_frame[-i - 1] if remaining else n_regions_per_frame[i] + frame_index = start_index + i if not remaining else np.max(window_timepoints) - i + # logger.debug(f"Frame {frame_index} has {obj_per_frame} objects.") + all_frames_embeddings[frame_index] = features[:obj_per_frame] + features = features[obj_per_frame:] + + def extract_embedding(self, masks, timepoints, labels, coords): + # if masks.shape[-2:] != self.orig_image_size: + # This should not be occuring since each folder is loaded as a separate CTCData + # However when computing augmented embeddings in parallel, the input size may change + # logger.debug(f"Input shape change detected: {masks.shape[-2:]} from {self.orig_image_size}.") + # self.orig_image_size = masks.shape[-2:] + n_regions_per_frame = np.unique(timepoints, return_counts=True)[1] + tot_regions = n_regions_per_frame.sum() + coords_txy = np.concatenate((timepoints[:, None], coords), axis=-1) + if coords_txy.shape[0] != tot_regions: + raise RuntimeError(f"Number of 
+        features = self.compute_region_features(
+            masks=masks,
+            coords=coords_txy,
+            timepoints=timepoints,
+            labels=labels,
+        )
+        if torch.isnan(features).any():
+            raise RuntimeError("NaN values found in features.")
+        if tot_regions != features.shape[0]:
+            raise RuntimeError(f"Number of regions ({tot_regions}) does not match the number of embeddings ({features.shape[0]}).")
+        return n_regions_per_frame, features
+
+    @abstractmethod
+    def _run_model(self, images, **kwargs) -> torch.Tensor:  # must return (B, grid_size**2, hidden_state_size)
+        """Extracts embeddings from the model."""
+        pass
+
+    def normalize_array(self, b):
+        b = percentile_norm(b)
+        if self.rescale_batches:
+            b = b * 255.0
+        return b
+
+    @staticmethod
+    def get_centroids_from_masks(masks: np.ndarray) -> np.ndarray:
+        """Computes the centroids of the objects in the masks.
+
+        Args:
+            masks: (n_objects, H, W) array of masks.
+
+        Returns:
+            Centroids: (n_objects, 2) array of (y, x) centroid coordinates, normalized to [0, 1].
+        """
+        regions = regionprops(masks)
+        centroids = np.array([region.centroid for region in regions])
+        centroids[:, 0] = centroids[:, 0] / masks.shape[1]
+        centroids[:, 1] = centroids[:, 1] / masks.shape[2]
+        return centroids
+
+    def apply_rot_to_features(self, features: torch.Tensor, centroids: np.ndarray) -> torch.Tensor:
+        """Applies a rotation to each feature vector based on the object's centroid.
+
+        Args:
+            features: (n_objects, hidden_state_size) tensor of features.
+            centroids: (n_objects, 2) array of (y, x) centroid coordinates, normalized to [0, 1].
+
+        Returns:
+            Rotated features: (n_objects, hidden_state_size)
+        """
+        n_objects, d = features.shape
+        assert d % 2 == 0, "Feature dimension must be even for rotation."
+        angle_x = torch.from_numpy(2 * np.pi * centroids[:, 0]).to(features.device)
+        angle_y = torch.from_numpy(2 * np.pi * centroids[:, 1]).to(features.device)
+
+        angles = torch.stack([angle_x, angle_y], dim=1).repeat(1, d // 2)
+        angles = angles.view(n_objects, d)
+        cos = torch.cos(angles)
+        sin = torch.sin(angles)
+        features_ = features.view(n_objects, -1, 2)
+        x_feat, y_feat = features_[..., 0], features_[..., 1]
+        x_rot = x_feat * cos[:, ::2] - y_feat * sin[:, ::2]
+        y_rot = x_feat * sin[:, ::2] + y_feat * cos[:, ::2]
+        rotated = torch.stack([x_rot, y_rot], dim=-1).reshape(n_objects, d)
+        if torch.allclose(rotated, features):
+            logger.warning("Rotated features are equal to original features. Rotation may not be applied correctly.")
+        return rotated
+
+    def _prepare_batches(self, images):
+        """Prepares batches of images for embedding extraction."""
+        if self.do_normalize:
+            # normalize_array already applies the optional 255 rescaling
+            images = self.normalize_array(images)
+        elif self.rescale_batches:
+            images = images * 255.0
+        for i in range(0, len(images), self.batch_size):
+            end = i + self.batch_size
+            end = min(end, len(images))
+            batch = np.expand_dims(images[i:end], axis=1)  # (B, C, H, W)
+
+            timepoints = range(i, end)
+            if self.n_channels > 1:  # repeat channels if needed
+                if self.orig_n_channels > 1 and self.orig_n_channels != self.n_channels:
+                    raise ValueError("When more than one original channel is provided, the number of channels in the model must match the number of channels in the input.")
+                batch = np.repeat(batch, self.n_channels, axis=1)
+            if not self.channel_first:
+                batch = np.moveaxis(batch, 1, -1)
+            if self.batch_return_type == "list[np.ndarray]":
+                batch = list(batch)
+            yield timepoints, batch
+
+    @staticmethod
+    def normalize_tensor(embeddings: torch.Tensor, norm: bool = True) -> torch.Tensor:
+        """Normalizes the embeddings by dividing by the norm."""
+        if norm:
+            embeddings = embeddings / (embeddings.norm(dim=-1, keepdim=True) + 1e-8)
+        return embeddings
+
+    def _map_coords_to_model_grid(self, coords):
+        scale_x = self.input_size[0] / self.orig_image_size[0]
+        scale_y = self.input_size[1] / self.orig_image_size[1]
+        coords = np.array(coords)
+        patch_x = (coords[:, 1] * scale_x).astype(int)
+        patch_y = (coords[:, 2] * scale_y).astype(int)
+        patch_coords = np.column_stack((coords[:, 0], patch_x, patch_y))
+        return patch_coords
+
+    def _find_nearest_cell(self, patch_coords):
+        """Finds the nearest cell in the grid for each patch coordinate."""
+        x_idxs = patch_coords[:, 1] // self.model_patch_size[0]
+        y_idxs = patch_coords[:, 2] // self.model_patch_size[1]
+        patch_idxs = np.column_stack((patch_coords[:, 0], x_idxs, y_idxs)).astype(int)
+        return patch_idxs
+
+    def _find_bbox_cells(self, regions: list, cell_height: int, cell_width: int):
+        """Finds the grid cells that each region's bounding box intersects.
+
+        Args:
+            regions (list): ``RegionProperties`` entries from ``regionprops``. Each must provide ``bbox``.
+            cell_height (int): Height of a cell in the grid.
+            cell_width (int): Width of a cell in the grid.
+
+        Returns:
+            dict: Mapping from region label to the grid cell indices that its bounding box intersects.
+        """
+        mask_patches = {}
+        for region in regions:
+            minr, minc, maxr, maxc = region.bbox
+            patches = self._find_region_cells(minr, minc, maxr, maxc, cell_height, cell_width)
+            mask_patches[region.label] = patches
+
+        return mask_patches
+
+    @staticmethod
+    def _find_region_cells(minr, minc, maxr, maxc, cell_height, cell_width):
+        start_patch_y = minr // cell_height
+        end_patch_y = (maxr - 1) // cell_height
+        start_patch_x = minc // cell_width
+        end_patch_x = (maxc - 1) // cell_width
+        patches = np.array([(i, j) for i in range(start_patch_y, end_patch_y + 1) for j in range(start_patch_x, end_patch_x + 1)])
+        return patches
+
+    def _find_patches_for_masks(self, image_mask: np.ndarray) -> dict:
+        """Find which patches in a grid each mask belongs to using regionprops.
+
+        Args:
+            image_mask (np.ndarray): Masks where each region has a unique label.
+
+        Returns:
+            mask_patches (dict): Dictionary with region labels as keys and lists of patch indices as values.
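+
+        Example:
+            Illustrative values (not from a real run): with a 512x512 mask and a
+            4x4 grid (128x128 cells), a region with ``bbox == (0, 0, 200, 130)``
+            spans grid cells (0, 0), (0, 1), (1, 0) and (1, 1).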
+ """ + patch_height = image_mask.shape[0] // self.final_grid_size[0] + patch_width = image_mask.shape[1] // self.final_grid_size[1] + regions = regionprops(image_mask) + return self._find_bbox_cells(regions, patch_height, patch_width) + + def _debug_show_patches(self, embeddings, masks, coords, patch_idxs): + import napari + v = napari.Viewer() + # v.add_labels(masks) + e = embeddings.detach().cpu().numpy().swapaxes(1, 2) + e = e.reshape(-1, self.hidden_state_size, self.final_grid_size[0], self.final_grid_size[1]).swapaxes(0, 1) + + v.add_image( + e, + name="Embeddings", + ) + # add red points at patch indices for the relevant frame + points = np.zeros((len(patch_idxs) * self.hidden_state_size, 3)) + for i, (t, y, x) in enumerate(patch_idxs): + point = np.array([t, y, x]) + points[i * self.hidden_state_size:(i + 1) * self.hidden_state_size] = np.tile(point, (self.hidden_state_size, 1)) + + v.add_points(points, size=1, face_color='red', name='Patch Indices') + + from skimage.transform import resize + masks_resized = resize(masks[0], (self.final_grid_size[0], self.final_grid_size[1]), anti_aliasing=False, order=0, preserve_range=True) + v.add_labels(masks_resized) + logger.debug(f"Lost labels : {set(np.unique(masks)) - set(np.unique(masks_resized))}") + + napari.run() + + def _nearest_patches(self, coords, masks=None, norm=True, embs=None): + """Finds the nearest patches to the detections in the embedding.""" + # find coordinate patches from detections + patch_coords = self._map_coords_to_model_grid(coords) + patch_idxs = self._find_nearest_cell(patch_coords) + # logger.debug(f"Patch indices: {patch_idxs}") + + # load the embeddings and extract the relevant ones + feats = torch.zeros(len(coords), self.hidden_state_size, device=self.device) + indices = [y * self.final_grid_size[1] + x for _, y, x in patch_idxs] + unique_timepoints = list(set(t for t, _, _ in patch_idxs)) + # logger.debug(f"Unique timepoints: {unique_timepoints}") + embeddings = self._load_features() if embs is None else embs + + # try: + # t = coords[0][0] + # if t == 3: + # self._debug_show_patches(embeddings, masks, coords, patch_idxs) + # except IndexError: + # logger.debug("No timepoint found in coords.") + + # logger.debug(f"Embeddings shape: {embeddings.shape}") + embeddings_dict = {t: embeddings[t] for t in unique_timepoints} + try: + for i, (t, _, _) in enumerate(patch_idxs): + feats[i] = embeddings_dict[t][indices[i]] + except KeyError as e: + logger.error(f"KeyError: {e} - Check if the timepoint exists in embeddings_dict.") + except IndexError as e: + # TODO improve handling of this error. Maybe check shape earlier + logger.error(f"IndexError: {e} - Embeddings exist but do not have the correct shape. Did the model input size change ? If so, please delete saved embeddings and recompute.") + feats = FeatureExtractor.normalize_tensor(feats, norm=norm) + if self.apply_rope: + centroids = FeatureExtractor.get_centroids_from_masks(masks) + feats = self.apply_rot_to_features(feats, centroids) + return feats + + # @average_time_decorator + def _agg_patches_bbox(self, masks, timepoints, labels, agg=torch.mean, norm=True, embs=None): + """Averages the embeddings of all patches that intersect with the detection. + + Args: + - masks (np.ndarray): Masks where each region has a unique label (t x H x W). + - timepoints (np.ndarray): For each region, contains the corresponding timepoint. (n_regions) + - labels (np.ndarray): Unique labels of the regions. 
(n_regions)
+            - agg (callable): Aggregation function to use for aggregating the embeddings.
+        """
+        try:
+            n_regions = len(timepoints)
+            timepoints_shifted = timepoints - timepoints.min()
+        except ValueError:
+            logger.error("Error: issue computing shifted timepoints.")
+            logger.error(f"Regions: {len(timepoints)}")
+            logger.error(f"Timepoints: {timepoints}")
+            return torch.zeros(n_regions, self.hidden_state_size, device=self.device)
+
+        feats = torch.zeros(n_regions, self.hidden_state_size, device=self.device)
+        times = np.unique(timepoints_shifted)
+        patches_res = joblib.Parallel(n_jobs=8, backend="threading")(
+            joblib.delayed(self._find_patches_for_masks)(masks[t]) for t in times
+        )
+        patches = {t: patch for t, patch in zip(times, patches_res)}
+        # logger.debug(f"Patches : {patches}")
+
+        embeddings = self._load_features() if embs is None else embs
+
+        def process_region(i, t):
+            patches_feats = []
+            for patch in patches[t][labels[i]]:
+                # patches are (row, col); flatten row-major onto the grid
+                embs = embeddings[t][patch[0] * self.final_grid_size[1] + patch[1]]
+                embs = FeatureExtractor.normalize_tensor(embs, norm=norm)
+                patches_feats.append(embs)
+            aggregated = agg(torch.stack(patches_feats), dim=0)
+            # If agg is torch.max, extract only the values
+            if isinstance(aggregated, torch.return_types.max):
+                aggregated = aggregated.values
+            return aggregated
+
+        res = joblib.Parallel(n_jobs=8, backend="threading")(
+            joblib.delayed(process_region)(i, t) for i, t in enumerate(timepoints_shifted)
+        )
+
+        for i, r in enumerate(res):
+            feats[i] = r
+
+        return feats
+
+    def _agg_patches_debug_view(self, v, region_mask, lab=None):
+        """Debug function to visualize the patches and their embeddings."""
+        # Add region mask
+        v.add_labels(
+            region_mask,
+            name=f"Region Mask {lab}",
+            opacity=0.5,
+            blending="translucent",
+        )
+
+    def _view_embeddings(self, embeddings, context: dict | None = None, **kwargs):
+        # If this causes issues with augmented feature computation because
+        # the extractor grid size depends on image size,
+        # redefine it as appropriate in the subclass.
+        # Use context to pass information as needed for the use case.
+
+        # Currently redefined in:
+        # - CoTrackerFeatures
+        embs = embeddings.view(
+            -1, self.final_grid_size[0], self.final_grid_size[1], self.hidden_state_size
+        )
+        return embs, self.final_grid_size
+
+    def _agg_patches_exact(self, masks, timepoints, labels, agg=torch.mean, norm=True, embs=None):
+        """Aggregates the embeddings of all patches that strictly belong to the mask."""
+        try:
+            n_regions = len(timepoints)
+            timepoints_shifted = timepoints - timepoints.min()
+        except ValueError:
+            logger.error("Error: issue computing shifted timepoints.")
+            logger.error(f"Regions: {len(timepoints)}")
+            logger.error(f"Timepoints: {timepoints}")
+            return torch.zeros(n_regions, self.hidden_state_size, device=self.device)
+
+        feats = torch.zeros(n_regions, self.hidden_state_size, device=self.device)
+        embeddings = self._load_features() if embs is None else embs
+        embeddings, grid_size = self._view_embeddings(embeddings, context={"masks_shape": masks.shape})
+
+        _T, H, W = masks.shape
+        # assert embeddings.shape[0] == _T, f"Embeddings times {embeddings.shape} does not match masks times {_T}."
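+        # How the exact-patch aggregation works: every foreground pixel (y, x) of a
+        # region is mapped onto the embedding grid by scaling with grid_size / image_size,
+        # and the per-pixel embeddings are then aggregated (mean/max/median) per region.
+        # Illustrative example (assumed values): with H = W = 512 and a 64x64 grid,
+        # pixel (100, 200) lands in grid cell (12, 25).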
+        grid_H, grid_W = grid_size
+        scale_y = grid_H / H
+        scale_x = grid_W / W
+
+        if self._debug_view:
+            import napari
+            if napari.current_viewer() is None:
+                v = napari.Viewer()
+            else:
+                v = napari.current_viewer()
+            if "Masks" not in v.layers:
+                v.add_labels(masks[0], name="Masks")
+            if "Embeddings" not in v.layers:
+                embs = embeddings.view(
+                    -1, self.final_grid_size[0], self.final_grid_size[1], self.hidden_state_size
+                )
+                v.add_image(
+                    embs.permute(3, 0, 1, 2).cpu().numpy(),
+                    name="Embeddings",
+                    colormap="inferno",
+                )
+
+        def process_region(i, t, masks):
+            if masks.shape[0] == 1:
+                masks = masks.squeeze(0)
+                mask_reg = masks == labels[i]
+            else:
+                mask_reg = masks[t] == labels[i]
+            if not np.any(mask_reg):
+                logger.warning(f"No pixels found for region {labels[i]} at timepoint {t}.")
+                # small values instead of zeros to avoid division by zero downstream
+                return torch.full((self.hidden_state_size,), 1e-8, device=self.device)
+
+            y_idxs, x_idxs = np.nonzero(mask_reg)
+            grid_y = np.clip((y_idxs * scale_y).astype(int), 0, grid_H - 1)
+            grid_x = np.clip((x_idxs * scale_x).astype(int), 0, grid_W - 1)
+            patch_embeddings = embeddings[timepoints[i]][grid_y, grid_x]
+            # normalizing before the mean seems most effective
+            patch_embeddings = FeatureExtractor.normalize_tensor(patch_embeddings, norm=norm)
+
+            if self._debug_view:
+                mask_emb = np.zeros((grid_H, grid_W), dtype=np.uint16)
+                mask_emb[grid_y, grid_x] = labels[i]
+                self._agg_patches_debug_view(v, mask_emb, labels[i])
+
+            if patch_embeddings.shape[0] == 0:
+                logger.warning(f"No mapped pixels for region {labels[i]} at timepoint {t}.")
+                return torch.zeros(self.hidden_state_size, device=self.device)
+            return agg(patch_embeddings, dim=0)
+
+        # Parallel processing
+        if self.parallel:
+            res = joblib.Parallel(n_jobs=8, backend="threading")(
+                joblib.delayed(process_region)(i, t, masks=masks) for i, t in enumerate(timepoints_shifted)
+            )
+            for i, r in enumerate(res):
+                # If agg is torch.max or torch.median, extract only the values
+                if isinstance(r, (torch.return_types.max, torch.return_types.median)):
+                    feats[i] = r.values
+                else:
+                    feats[i] = r
+        else:
+            for i, t in enumerate(timepoints_shifted):
+                feats[i] = process_region(i, t, masks)
+        if self._debug_view:
+            napari.run()
+        # if norm:
+        #     feats = feats / feats.norm(dim=-1, keepdim=True)
+        return feats
+
+    def _exact_patch(self, masks, timepoints, labels, norm=False):
+        """Returns all embeddings overlapping with the mask of each object."""
+        try:
+            n_regions = len(timepoints)
+            timepoints_shifted = timepoints - timepoints.min()
+        except ValueError:
+            logger.error("Error: issue computing shifted timepoints.")
+            logger.error(f"Regions: {len(timepoints)}")
+            logger.error(f"Timepoints: {timepoints}")
+            return torch.zeros(n_regions, self.hidden_state_size, device=self.device)
+
+        feats = torch.zeros(n_regions, self.hidden_state_size, device=self.device)
+        embeddings = self._load_features()
+        embeddings = embeddings.view(
+            -1, self.final_grid_size[0], self.final_grid_size[1], self.hidden_state_size
+        )
+
+        _T, H, W = masks.shape
+        grid_H, grid_W = self.final_grid_size
+        scale_y = grid_H / H
+        scale_x = grid_W / W
+
+        def process_region(i, t, masks):
+            if masks.shape[0] == 1:  # single timepoint
+                masks = masks.squeeze(0)
+                mask_reg = masks == labels[i]
+            else:  # all timepoints
+                mask_reg = masks[t] == labels[i]
+            if not np.any(mask_reg):
+                logger.warning(f"No pixels found for region {labels[i]} at timepoint {t}.")
+                return 
torch.zeros(self.hidden_state_size, device=self.device) + + y_idxs, x_idxs = np.nonzero(mask_reg) + grid_y = np.clip((y_idxs * scale_y).astype(int), 0, grid_H - 1) + grid_x = np.clip((x_idxs * scale_x).astype(int), 0, grid_W - 1) + patch_embeddings = embeddings[timepoints[i]][grid_y, grid_x] + if patch_embeddings.shape[0] == 0: + logger.warning(f"No mapped pixels for region {labels[i]} at timepoint {t}.") + return torch.zeros(self.hidden_state_size, device=self.device) + return patch_embeddings + + # Parallel processing + res = joblib.Parallel(n_jobs=8, backend="threading")( + joblib.delayed(process_region)(i, t, masks=masks) for i, t in enumerate(timepoints_shifted) + ) + for i, r in enumerate(res): + feats[i] = r + # for i, t in enumerate(timepoints_shifted): + # feats[i] = process_region(i, t, masks) + + if norm: + feats = feats / feats.norm(dim=-1, keepdim=True) + return feats + + def _save_features(self, features): # , timepoint): + """Saves the features to disk.""" + # save_path = self.save_path / f"{timepoint}_{self.model_name_path}_features.npy" + self.embeddings = features + if not self.do_save: + return + save_path = self.save_path / f"{self.model_name_path}_features.npy" + np.save(save_path, features.cpu().numpy()) + assert save_path.exists(), f"Failed to save features to {save_path}" + + def _load_features(self): # , timepoint): + """Loads the features from disk.""" + # load_path = self.save_path / f"{timepoint}_{self.model_name_path}_features.npy" + if self.embeddings is None: + load_path = self.save_path / f"{self.model_name_path}_features.npy" + if load_path.exists(): + features = np.load(load_path) + assert features is not None, f"Failed to load features from {load_path}" + if np.any(np.isnan(features)): + raise RuntimeError(f"NaN values found in features loaded from {load_path}.") + # check feature shape consistency + if features.shape[1] != self.final_grid_size[0] * self.final_grid_size[1] or features.shape[2] != self.hidden_state_size: + logger.error(f"Saved embeddings found, but shape {features.shape} does not match expected shape {('n_frames', self.final_grid_size[0] * self.final_grid_size[1], self.hidden_state_size)}.") + logger.error("Embeddings will be recomputed.") + return None + logger.info("Saved embeddings loaded.") + self.embeddings = torch.tensor(features).to(self.device) + return self.embeddings + else: + logger.info(f"No saved embeddings found at {load_path}. Features will be computed.") + return None + else: + return self.embeddings + + def _check_missing_embeddings(self): + """Checks if embeddings for the model already exist or are missing. + + Returns whether the embeddings need to be recomputed. + """ + if self.force_recompute: + return True + try: + features = self._load_features() + except FileNotFoundError: + return True + if features is None: + return True + else: + logger.info(f"Embeddings for {self.model_name} already exist. 
Skipping embedding computation.") + return False + + +class FeatureExtractorAugWrapper: + """Wrapper for the FeatureExtractor class to apply augmentations.""" + def __init__( + self, + extractor: FeatureExtractor, + augmenter: "PretrainedAugmentations", + n_aug: int = 1, + force_recompute: bool = False, + ): + self.extractor = extractor + self.additional_features = extractor.additional_features + self.n_aug = n_aug + self.aug_pipeline = augmenter + self.all_aug_features = {} # n_aug -> {aug_id: {metadata, data}} + # data -> {t: {lab: {"coords": coords, "features": features}}} + self.image_shape_reference = {} + + self.extractor.force_recompute = True + self.extractor.do_save = False # do not save intermediate features (augmented image embeddings) + # instead, we will save the augmented features + coordinates on a per-object basis in a zarr store + self.extractor.do_normalize = False + self.extractor.parallel = False + # already parallelized, faster this way since it avoids the overhead of spawning + # many small processes within the parallelized augmentation pipeline + + self._zarr_sync = zarr.ProcessSynchronizer(str(self.get_save_path()) + ".sync") + self.force_recompute = force_recompute + + self._debug_view = None + + def get_save_path(self): + root_path = self.extractor.save_path / "aug" + if not root_path.exists(): + root_path.mkdir(parents=True, exist_ok=True) + return root_path / f"{self.extractor.model_name_path}_aug.zarr" + + def _check_existing(self): + save_path = self.get_save_path() + if not save_path.exists() or self.force_recompute: + logger.debug(f"Augmentation zarr store {save_path} does not exist or force_recompute is True. Recomputing features.") + return False, None, None + logger.info(f"Loading existing features from {save_path}...") + # root = zarr.open_group(str(save_path), mode="r") + # existing_augs = [k for k in root.keys() if k.isdigit()] + features_dict = self.load_all_features() + existing_augs = list(features_dict.keys()) + logger.info("Done.") + return True, existing_augs, features_dict + + def _compute(self, images, masks): + """Computes the features for the images and masks.""" + images_shape = images.shape + if len(images_shape) != 3: + images_shape = images_shape[1:] # remove batch dimension if present + embs = self.extractor.precompute_image_embeddings(images, image_shape=images_shape) + + if self._debug_view is not None: + embs = embs.cpu().numpy() + logger.debug(f"Embeddings shape: {embs.shape}") + embs = embs.reshape(-1, self.extractor.final_grid_size[0], self.extractor.final_grid_size[1], self.extractor.hidden_state_size) + embs = np.moveaxis(embs, 3, 0) + self._debug_view.add_image(embs, name="Embeddings", colormap="inferno") + self._debug_view.add_image(images.cpu().numpy(), name="Images", colormap="viridis") + self._debug_view.add_labels(masks.cpu().numpy(), name="Masks") + + images, masks = images.cpu().numpy(), masks.cpu().numpy() + # features = wrfeat.WRAugPretrainedFeatures.from_mask_img( + # img=images, + # mask=masks, + # feature_extractor=self.extractor, + # t_start=0, + # additional_properties=self.extractor.additional_features, + # ) + features = [ + wrfeat.WRAugPretrainedFeatures.from_mask_img( + # embeddings=embs, + img=img[np.newaxis], + mask=mask[np.newaxis], + feature_extractor=self.extractor, + t_start=t, + additional_properties=self.extractor.additional_features + ) + for t, (mask, img) in tqdm( + enumerate(zip(masks, images)), desc="Computing features...", total=len(masks), leave=False + ) # if t == 10 debug + ] + features_dict = 
{t: v for f in features for t, v in f.to_dict().items()} + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return features_dict + + def _compute_original(self, images, masks): + """Computes the original features for the images and masks.""" + images, masks = self.aug_pipeline.preprocess(images, masks, normalize_func=self.extractor.normalize_array) + orig_feat_dict = self._compute(images, masks) + self.image_shape_reference[0] = images.shape[-2:] + return orig_feat_dict + + def _compute_augmented(self, images, masks, n): + images, masks = self.aug_pipeline.preprocess(images, masks, normalize_func=self.extractor.normalize_array) + aug_images, aug_masks, aug_record = self.aug_pipeline(images, masks) + + # check for NaNs + if torch.isnan(aug_images).any() or torch.isnan(aug_masks).any(): + raise RuntimeError("NaN values found in augmented images or masks.") + + im_shape, masks_shape = aug_images.shape, aug_masks.shape + assert im_shape == masks_shape, f"Augmented images shape {im_shape} does not match augmented masks shape {masks_shape}." + if im_shape[-2:] != self.extractor.orig_image_size: + # if isinstance(self.extractor, TAPFeatures): + # self.extractor.final_grid_size = (im_shape[-2], im_shape[-1]) # TAP features have same dims as images + if isinstance(self.extractor, CoTrackerFeatures): + stride = self.extractor.model.stride + self.extractor.final_grid_size = (im_shape[-2] // stride, im_shape[-1] // stride) + if im_shape[-1] == 0 or im_shape[-2] == 0: + raise ValueError(f"Augmented images have invalid shape {im_shape}. Cannot extract features.") + self.extractor.orig_image_size = im_shape[-2:] + + aug_feat_dict = self._compute(aug_images, aug_masks) + self.image_shape_reference[n] = aug_images.shape[-2:] + return aug_feat_dict, aug_record + + def _process_aug(self, n, images, masks, existing_aug_ids, existing_features_dict): + try: + if str(f"{n + 1}") in existing_aug_ids: + logger.info(f"Augmentation {n + 1} already exists. Skipping computation.") + aug_feat_dict = existing_features_dict[str(n + 1)]["data"] + aug_record = existing_features_dict[str(n + 1)]["metadata"] + else: + aug_feat_dict, aug_record = self._compute_augmented(images, masks, n=n + 1) + result = { + "n": n, + "metadata": aug_record, + "data": aug_feat_dict, + # "should_save": str(n + 1) not in existing_aug_ids, + } + if str(n + 1) not in existing_aug_ids: + self._save_features(n + 1, result) + return result + except Exception as e: + logger.error(f"Error processing augmentation {n + 1}: {e}") + raise e + + def compute_all_features(self, images, masks, clear_mem=True, n_workers=8) -> dict: + """Augments the images and masks, computes the embeddings, and saves features incrementally.""" + # check existing features + present, existing_augs, existing_features_dict = self._check_existing() + save_path = self.get_save_path() + if present: + logger.debug(f"Saved features found at {save_path}.") + if len(existing_augs) == self.n_aug + 1: + logger.info(f"All {self.n_aug} augmentations + original already exist. 
Loading existing features.") + self.all_aug_features = existing_features_dict + return self.all_aug_features + else: + logger.info("No existing augmentations found.") + logger.debug(f"Existing augmentations: {existing_augs}") + if existing_augs is None: + existing_aug_ids = [] + else: + existing_aug_ids = existing_augs + logger.debug(f"Existing augmentations IDs: {existing_aug_ids}") + + if "0" not in existing_aug_ids: + orig_feat_dict = self._compute_original(images, masks) + else: + logger.info("Original features already exist. Skipping computation for original features.") + orig_feat_dict = existing_features_dict["0"]["data"] + + self.all_aug_features = { + "0": { + "data": orig_feat_dict, + "metadata": { + "image_shape": images.shape, + } + } + } + if "0" not in existing_aug_ids: + self._save_features(0, self.all_aug_features["0"]) + + disable_parallel = False + if isinstance(self.extractor, CoTrackerFeatures) or isinstance(self.extractor, TAPFeatures) or isinstance(self.extractor, MicroSAMFeatures): + # CoTrackerFeatures and TAP uses a different grid size for each image, + # which requires a different approach to parallel processing. + # As a quick fix, parallel processing is disabled + # TODO make necessary changes to CoTrackerFeatures to allow parallel processing + disable_parallel = True + logger.debug(f"Disabling parallel processing for {self.extractor.__class__.__name__} due to variable grid size.") + + if n_workers == 0 or disable_parallel: + for n in range(self.n_aug): + res = self._process_aug(n, images, masks, existing_aug_ids, existing_features_dict) + self.all_aug_features[str(n + 1)] = res + # if res["should_save"]: + # self._save_features(n + 1, self.all_aug_features[str(n + 1)]) + else: + # joblib parallel processing + results = joblib.Parallel(n_jobs=n_workers, backend="threading")( + joblib.delayed(self._process_aug)( + n, images, masks, existing_aug_ids, existing_features_dict + ) for n in range(self.n_aug) + ) + for res in results: + n = res["n"] + self.all_aug_features[str(n + 1)] = { + "metadata": res["metadata"], + "data": res["data"], + } + # if res["should_save"]: + # self._save_features(n + 1, self.all_aug_features[str(n + 1)]) + + if clear_mem: + self.extractor.embeddings = self.extractor.embeddings.cpu() + try: + self.extractor.model = self.extractor.model.cpu() + except AttributeError as e: + logger.error(f"Model attribute not found: {e}. 
Skipping model transfer to CPU.") + self.extractor = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return self.all_aug_features + + # def _create_feat_dict(self, labels, ts, coords, features): + # """Creates a dictionary with the augmented features.""" + # aug_feat_dict = {} + # features = features.cpu().numpy() + # for i, (t, lab) in enumerate(zip(ts, labels)): + # t = int(t) + # lab = int(lab) + # if t not in aug_feat_dict: + # aug_feat_dict[t] = {} + # aug_feat_dict[t][lab] = { + # "coords": coords[i], + # "features": features[i], + # } + # return aug_feat_dict + + def _save_features(self, aug_id: int, aug_data: dict): + """Saves the features for a specific augmentation to disk as zarr (fast, flat layout).""" + import zarr + + save_path = self.get_save_path() + compressor = Blosc(cname='zstd', clevel=3, shuffle=Blosc.BITSHUFFLE) + root = zarr.open_group(str(save_path), mode="a", synchronizer=self._zarr_sync) + group_name = str(aug_id) + if group_name in root: + del root[group_name] + group = root.create_group(group_name) + group.attrs["metadata"] = json.dumps(aug_data.get("metadata", {})) + + coords_list = [] + t_list = [] + lab_list = [] + features_dict = {} + + for t, data in aug_data["data"].items(): + for lab, lab_data in data.items(): + coords_list.append(lab_data["coords"]) + t_list.append(t) + lab_list.append(lab) + for key, value in lab_data["features"].items(): + if key not in features_dict: + features_dict[key] = [] + features_dict[key].append(value) + + coords_arr = np.stack(coords_list) + t_arr = np.array(t_list) + lab_arr = np.array(lab_list) + group.create_dataset("coords", data=coords_arr, compressor=compressor) + group.create_dataset("timepoints", data=t_arr, compressor=compressor) + group.create_dataset("labels", data=lab_arr, compressor=compressor) + + features_group = group.create_group("features") + for key, values in features_dict.items(): + features_group.create_dataset(key, data=np.stack(values), compressor=compressor) + + # logger.debug(f"Augmented features for augmentation {aug_id} saved to {save_path}.") + + def load_all_features(self) -> dict: + """Loads all features from disk.""" + save_path = self.get_save_path() + if not save_path.exists(): + raise FileNotFoundError(f"Path {save_path} does not exist.") + + features = FeatureExtractorAugWrapper.load_features( + save_path, + additional_props=self.additional_features, + ) + self.all_aug_features = features + return features + + @staticmethod + def load_features(path: str | Path, additional_props: str | None = None, n_jobs: int = 12) -> dict: + """Loads the features for all augmentations from disk (flat zarr layout, parallelized).""" + import joblib + + if not isinstance(path, Path): + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Path {path} does not exist.") + + if additional_props is not None: + required_features = wrfeat._PROPERTIES[additional_props] + if len(required_features) == 0: + required_features = "pretrained_feats" + else: + required_features += ("pretrained_feats", ) + + root = zarr.open_group(str(path), mode="r") + aug_ids = [k for k in root.keys() if k.isdigit()] + + def _load_single(aug_id): + missing = False + if aug_id not in root: + missing = True + group = root[aug_id] + try: + metadata = json.loads(group.attrs["metadata"]) + except KeyError: + metadata = None + try: + coords_arr = group["coords"][...] + t_arr = group["timepoints"][...] + lab_arr = group["labels"][...] 
+ features_group = group["features"] + features_dict = {} + for key in features_group.keys(): + if additional_props is not None and key not in required_features: + continue + features_dict[key] = features_group[key][...] + if additional_props is not None: + missing_keys = [k for k in required_features if k not in features_dict] + if missing_keys: + raise RuntimeError( + f"Missing required features {missing_keys} in augmentation {aug_id}. " + f"Please delete the cache at {path} and recompute the features." + ) + + data = {} + for i in range(len(t_arr)): + t = int(t_arr[i]) + lab = int(lab_arr[i]) + if t not in data: + data[t] = {} + feats = {k: features_dict[k][i] for k in features_dict} + data[t][lab] = { + "coords": coords_arr[i], + "features": feats, + } + return aug_id, {"metadata": metadata, "data": data} + except KeyError as e: + logger.error(f"KeyError: {e} - Augmentation {aug_id} is missing some data. Skipping.") + missing = True + if missing: + return aug_id, None + raise RuntimeError(f"Augmentation {aug_id} could not be loaded. Missing data or invalid format.") + + results = joblib.Parallel(n_jobs=n_jobs, backend="threading")( + joblib.delayed(_load_single)(aug_id) for aug_id in aug_ids + ) + all_data = {aug_id: aug_data for aug_id, aug_data in results if aug_data is not None} + return all_data + + +############## +@register_backbone("facebook/hiera-tiny-224-hf", 768) +class HieraFeatures(FeatureExtractor): + model_name = "facebook/hiera-tiny-224-hf" + + def __init__( + self, + image_size: tuple[int, int], + save_path: str | Path, + batch_size: int = 16, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + mode: PretrainedFeatsExtractionMode = "nearest_patch", + **kwargs, + ): + super().__init__(image_size, save_path, batch_size, device, mode, **kwargs) + # self.input_size = 224 + self.input_mul = 3 + self.input_size = int(self.input_mul * 224) + self.final_grid_size = int(7 * self.input_mul) # default is 7x7 grid + self.n_channels = 3 + self.hidden_state_size = 768 + self.rescale_batches = False + + ## + self.im_proc_kwargs["size"] = self.input_size + ## + self.image_processor = AutoImageProcessor.from_pretrained(self.model_name) + config = HieraConfig.from_pretrained(self.model_name) + config.image_size = [self.input_size[0], self.input_size[1]] + # logger.debug(f"Config: {config}") + # self.model = HieraModel.from_pretrained(self.model_name) + # self.model.config.image_size = [self.input_size, self.input_size] + self.model = HieraModel(config) + self.model.to(self.device) + # self.model.embeddings.patch_embeddings.num_patches = (self.input_size // self.model.config.patch_size[0]) ** 2 + # self.model.embeddings.position_embeddings = torch.nn.Parameter( + # torch.zeros(1, self.model.embeddings.patch_embeddings.num_patches + 1, self.hidden_state_size) + # ) + # self.model.embeddings.position_ids = torch.arange(0, self.model.embeddings.patch_embeddings.num_patches + 1).unsqueeze(0) + + def _run_model(self, images, **kwargs) -> torch.Tensor: + """Extracts embeddings from the model.""" + # images = self._normalize_batch(images) + inputs = self.image_processor(images, **self.im_proc_kwargs).to(self.device) + + with torch.no_grad(): + outputs = self.model(**inputs) + + return outputs.last_hidden_state + + +@register_backbone("facebook/dinov2-base", 768) +class DinoV2Features(FeatureExtractor): + model_name = "facebook/dinov2-base" + + def __init__( + self, + image_size: tuple[int, int], + save_path: str | Path, + batch_size: int = 16, + device: str = "cuda" if 
torch.cuda.is_available() else "cpu", + mode: PretrainedFeatsExtractionMode = "nearest_patch", + **kwargs, + ): + super().__init__(image_size, save_path, batch_size, device, mode) + self.input_size = 224 + self.final_grid_size = 16 # 16x16 grid + self.n_channels = 3 # expects RGB images + self.hidden_state_size = 768 + self.image_processor = AutoImageProcessor.from_pretrained(self.model_name) + ## + self.im_proc_kwargs["size"] = self.input_size + ## + self.model = AutoModel.from_pretrained(self.model_name) + self.rescale_batches = False + # config = Dinov2Config.from_pretrained(self.model_name) + # config.image_size = self.input_size + + # self.model = Dinov2Model(config) + # logger.info(f"Model from config: {self.model.config}") + # logger.info(f"Pretrained model : {Dinov2Model.from_pretrained(self.model_name).config}") + self.model.to(self.device) + + def _run_model(self, images, **kwargs) -> torch.Tensor: + """Extracts embeddings from the model.""" + inputs = self.image_processor(images, **self.im_proc_kwargs).to(self.device) + + with torch.no_grad(): + outputs = self.model(**inputs) + + # ignore the CLS token (not classifying) + # this way we get only the patch embeddings + # which are compatible with finding the relevant patches directly + # in the rest of the code + return outputs.last_hidden_state[:, 1:, :] + + +@register_backbone("facebook/sam-vit-base", 256) +class SAMFeatures(FeatureExtractor): + model_name = "facebook/sam-vit-base" + + def __init__( + self, + image_size: tuple[int, int], + save_path: str | Path, + batch_size: int = 4, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + mode: PretrainedFeatsExtractionMode = "nearest_patch", + **kwargs, + ): + super().__init__(image_size, save_path, batch_size, device, mode) + self.input_size = 1024 + self.final_grid_size = 64 # 64x64 grid + self.n_channels = 3 + self.hidden_state_size = 256 + self.image_processor = SamProcessor.from_pretrained(self.model_name) + self.model = SamModel.from_pretrained(self.model_name) + self.rescale_batches = False + + self.model.to(self.device) + + def _run_model(self, images, **kwargs) -> torch.Tensor: + """Extracts embeddings from the model.""" + inputs = self.image_processor(images, **self.im_proc_kwargs).to(self.device) + outputs = self.model.get_image_embeddings(inputs['pixel_values']) + B, N, H, W = outputs.shape + return outputs.permute(0, 2, 3, 1).reshape(B, H * W, N) # (B, grid_size**2, hidden_state_size) + + +@register_backbone("facebook/sam2-hiera-large", 256) +@register_backbone("facebook/sam2.1-hiera-base-plus", 256) +class SAM2Features(FeatureExtractor): + model_name = "facebook/sam2.1-hiera-base-plus" + + def __init__( + self, + image_size: tuple[int, int], + save_path: str | Path, + batch_size: int = 4, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + mode: PretrainedFeatsExtractionMode = "nearest_patch", + **kwargs, + ): + super().__init__(image_size, save_path, batch_size, device, mode) + self.input_size = 1024 + self.final_grid_size = 64 # 64x64 grid + self.n_channels = 3 + self.hidden_state_size = 256 + self.model = SAM2ImagePredictor.from_pretrained(self.model_name, device=self.device).model + + self.batch_return_type = "list[np.ndarray]" + self.channel_first = True + self.rescale_batches = False + + self._bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + if self.rescale_batches: + print("Rescaling batches to [0, 255] range.") + + @torch.no_grad() + def _run_model(self, images: list[np.ndarray], **kwargs) -> torch.Tensor: + 
"""Extracts embeddings from the model.""" + with torch.autocast(device_type=self.device), torch.inference_mode(): + images_ten = torch.stack([torch.tensor(image) for image in images]).to(self.device) + # logger.debug(f"Image dtype: {images_ten.dtype}") + # logger.debug(f"Image shape: {images_ten.shape}") + # logger.debug(f"Image min : {images_ten.min()}, max: {images_ten.max()}") + images_ten = F.interpolate(images_ten, size=(self.input_size[0], self.input_size[1]), mode="bilinear", align_corners=False) + # from torchvision.transforms.functional import resize + # images_ten = resize(images_ten, size=(self.input_size, self.input_size)) + # images_ten = normalize(images_ten, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + # out = self.model.image_encoder(images_ten) + out = self.model.forward_image(images_ten) + _, vision_feats, _, _ = self.model._prepare_backbone_features(out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model_name != "facebook/sam2.1-hiera-base-plus/highres": + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + feats = [ + feat.permute(1, 2, 0).view(feat.shape[1], -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + features = feats[-1] + # features = self.model.set_image_batch(images) + # features = self.model._features['image_embed'] + B, N, H, W = features.shape + return features.permute(0, 2, 3, 1).reshape(B, H * W, N) # (B, grid_size**2, hidden_state_size) + + +@register_backbone("facebook/sam2.1-hiera-base-plus/highres", 32) +class SAM2HighresFeatures(SAM2Features): + model_name = "facebook/sam2.1-hiera-base-plus" + + def __init__( + self, + image_size: tuple[int, int], + save_path: str | Path, + batch_size: int = 4, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + mode: PretrainedFeatsExtractionMode = "nearest_patch", + **kwargs, + ): + super().__init__(image_size, save_path, batch_size, device, mode) + self.final_grid_size = 256 # 256x256 grid + # self.final_grid_size = 128 + self.hidden_state_size = 32 + # self.hidden_state_size = 64 + + @property + def model_name_path(self): + """Returns the model name for saving.""" + p = f"{self.model_name}/highres" + return p.replace("/", "-") + + def _run_model(self, images: list[np.ndarray], **kwargs) -> torch.Tensor: + """Extracts embeddings from the model.""" + with torch.autocast(device_type=self.device), torch.inference_mode(): + images_ten = torch.stack([torch.tensor(image) for image in images]).to(self.device) + # logger.debug(f"Image dtype: {images_ten.dtype}") + # logger.debug(f"Image shape: {images_ten.shape}") + # logger.debug(f"Image min : {images_ten.min()}, max: {images_ten.max()}") + images_ten = F.interpolate(images_ten, size=(self.input_size[0], self.input_size[1]), mode="bilinear", align_corners=False) + out = self.model.forward_image(images_ten) + backbone_out, _, _, _ = self.model._prepare_backbone_features(out) + features = backbone_out["backbone_fpn"][0] # (B, N, H, W) + + B, N, H, W = features.shape + return features.permute(0, 2, 3, 1).reshape(B, H * W, N) # (B, grid_size**2, hidden_state_size) + + +@register_backbone("facebookresearch/co-tracker", 128) +class CoTrackerFeatures(FeatureExtractor): + model_name = "facebookresearch/co-tracker" + + def __init__( + self, + image_size: tuple[int, int], + save_path: str | Path, + batch_size: int = 4, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + mode: 
PretrainedFeatsExtractionMode = "nearest_patch", + **kwargs, + ): + super().__init__(image_size, save_path, batch_size, device, mode) + self.input_size = image_size + cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline") + self.model = cotracker.model + self.model.to(device) + self.final_grid_size = (image_size[0] // self.model.stride, image_size[1] // self.model.stride) + self.hidden_state_size = 128 + self.n_channels = 3 + self.fmaps_chunk_size = 8 + + self.batch_return_type = "list[np.ndarray]" + + def _view_embeddings(self, embeddings, context: dict, **kwargs): + image_shape = context.get("masks_shape", None) + H, W = image_shape[-2:] + grid_size = (H // self.model.stride, W // self.model.stride) + embs = embeddings.view( + -1, + *grid_size, + self.hidden_state_size, + ) + return embs, grid_size + + def precompute_image_embeddings(self, images, image_shape=None, **kwargs): # , windows, window_size): + """Precomputes embeddings for all images.""" + try: + if image_shape is None: + _, H, W = images.shape + if H != self.input_size[0] or W != self.input_size[1]: + self.input_size = (H, W) + self.final_grid_size = (H // self.model.stride, W // self.model.stride) + logger.debug(f"Updated CoTracker input size: {self.input_size}, final grid size: {self.final_grid_size}") + grid_size = self.final_grid_size + else: + H, W = image_shape[-2:] + grid_size = (H // self.model.stride, W // self.model.stride) + + missing = self._check_missing_embeddings() + all_embeddings = torch.zeros(len(images), grid_size[0] * grid_size[1], self.hidden_state_size, device=self.device) + if missing: + for ts, batches in tqdm(self._prepare_batches(images), total=len(images) // self.batch_size, desc="Computing embeddings", leave=False): + try: + embeddings = self._run_model(batches, image_shape, **kwargs) + except Exception as e: + breakpoint() + raise e + if torch.any(embeddings.isnan()): + raise RuntimeError("NaN values found in features.") + # logger.debug(f"Embeddings shape: {embeddings.shape}") + all_embeddings[ts] = embeddings.to(torch.float32) + assert embeddings.shape[-1] == self.hidden_state_size + self.embeddings = all_embeddings + self._save_features(all_embeddings) + # logger.debug(f"Precomputed embeddings shape: {self.embeddings.shape}") + return self.embeddings + except Exception as e: + logger.error(f"Error occurred while precomputing embeddings: {e}") + raise e + + def _run_model(self, images: list[np.ndarray], image_shape=None, **kwargs) -> torch.Tensor: + self.model.eval() + x = torch.stack([torch.tensor(image) for image in images]).to(self.device) + x = x.unsqueeze(0) # B, T, C, H, W + with torch.no_grad(): + B = x.shape[0] + T = x.shape[1] + C_ = x.shape[2] + H, W = x.shape[3], x.shape[4] + if T > self.batch_size: + fmaps = [] + for t in range(0, T, self.fmaps_chunk_size): + video_chunk = x[:, t : t + self.fmaps_chunk_size] + fmaps_chunk = self.model.fnet(video_chunk.reshape(-1, C_, H, W)) + T_chunk = video_chunk.shape[1] + C_chunk, H_chunk, W_chunk = fmaps_chunk.shape[1:] + fmaps.append(fmaps_chunk.reshape(B, T_chunk, C_chunk, H_chunk, W_chunk)) + fmaps = torch.cat(fmaps, dim=1).reshape(-1, C_chunk, H_chunk, W_chunk) + else: + fmaps = self.model.fnet(x.reshape(-1, C_, H, W)) + fmaps = fmaps.permute(0, 2, 3, 1) + fmaps = fmaps / torch.sqrt( + torch.maximum( + torch.sum(torch.square(fmaps), axis=-1, keepdims=True), + torch.tensor(1e-12, device=fmaps.device), + ) + ) + fmaps = fmaps.permute(0, 3, 1, 2).reshape( + B, -1, self.model.latent_dim, H // self.model.stride, W // 
self.model.stride + ) # B, T, N, H', W' + # end of original code + fmaps = fmaps.permute(0, 1, 3, 4, 2).squeeze(0) # T, H', W', N + fmaps = fmaps.reshape( + fmaps.shape[0], fmaps.shape[1] * fmaps.shape[2], fmaps.shape[3] + ) # T, H' * W', N + return fmaps + + +@register_backbone("microsam/vit_b_lm", 256) +@register_backbone("microsam/vit_l_lm", 256) +class MicroSAMFeatures(FeatureExtractor): + model_name = "microsam/vit_b_lm" + + def __init__( + self, + image_size: tuple[int, int], + save_path: str | Path, + batch_size: int = 4, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + mode: PretrainedFeatsExtractionMode = "nearest_patch", + **kwargs, + ): + if not MICRO_SAM_AVAILABLE: + raise ImportError("microSAM is not available. Please install it following the instructions in the documentation.") + super().__init__(image_size, save_path, batch_size, device, mode) + self.input_size = 1024 + self.final_grid_size = 64 + self.n_channels = 3 + self.hidden_state_size = 256 + model_name = self.model_name.split("/")[-1] + self.model = get_microsam_model(model_name, device=self.device) + + self.batch_return_type = "list[np.ndarray]" + self.channel_first = False + self.rescale_batches = False + + def _run_model(self, images: list[np.ndarray], **kwargs) -> torch.Tensor: + """Extracts embeddings from the model.""" + embeddings = [] + for image in images: + # logger.debug(f"Image shape: {image.shape}") + self.model.set_image(image) + embedding = self.model.get_image_embedding() # (1, hidden_state_size, grid_size, grid_size) + B, N, H, W = embedding.shape + embedding = embedding.permute(0, 2, 3, 1).reshape(B, H * W, N) + embeddings.append(embedding) + out = torch.stack(embeddings).squeeze() # (B, grid_size**2, hidden_state_size) + if len(out.shape) == 2: + out = out.unsqueeze(0) + return out + + +@register_backbone("weigertlab/tarrow", 32) +class TAPFeatures(FeatureExtractor): + model_name = "weigertlab/tarrow" + + def __init__( + self, + model_folder: str, + image_size: tuple[int, int], + save_path: str | Path, + batch_size: int = 2, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + mode: PretrainedFeatsExtractionMode = "nearest_patch", + **kwargs, + ): + super().__init__(image_size, save_path, batch_size, device, mode) + self._final_grid_size = (self.orig_image_size[0], self.orig_image_size[1]) + self.input_size = self.final_grid_size + self.n_channels = 1 + + self.model_folder = model_folder + self.full_model, self.hidden_state_size = self._load_model_from_path(model_folder) + self.full_model.to(device) + self.full_model.eval() + AVAILABLE_PRETRAINED_BACKBONES["weigertlab/tarrow"]["feat_dim"] = self.hidden_state_size + self.model = self.full_model.backbone + + self.batch_return_type = "torch.Tensor" + self.channel_first = False + self.rescale_batches = False + self.normalize_embeddings = False + + # TODO clear full model from memory + + @property + def final_grid_size(self) -> tuple[int, int]: + return self._final_grid_size + + @final_grid_size.setter + def final_grid_size(self, value: tuple[int, int]): + """Sets the final grid size and updates the model's input size.""" + self._final_grid_size = value + self.orig_image_size = value + self.input_size = value + self.model.input_size = value + self.model_patch_size = (1, 1) + + def _set_model_patch_size(self): + pass + + @staticmethod + def _load_model_from_path(model_folder: str): + """Loads the model from the folder.""" + if not os.path.exists(model_folder): + raise FileNotFoundError(f"Model folder {model_folder} does not 
exist.") + model = TimeArrowNet.from_folder(model_folder, from_state_dict=True) + feat_dim = model.bb_n_feat + return model, feat_dim + + def normalize_array(self, b): + images_batch = tap_normalize(b) + images_batch = torch.from_numpy(images_batch).to(torch.float32) # T, H, W + return images_batch + + def _prepare_batches(self, images): + if self.do_normalize: + images = self.normalize_array(images) + images = images.unsqueeze(1) # T, C, H, W + return images + + def _run_model(self, images: list[np.ndarray], **kwargs) -> torch.Tensor: + """Extracts embeddings from the model.""" + features = [] + ts = images.shape[0] + im_shape = tuple(images.shape[-2:]) + self.orig_image_size = im_shape + self.final_grid_size = im_shape + with torch.no_grad(): + for i in tqdm(range(0, len(images), self.batch_size), desc="Computing TAP features", leave=False): + batch = images[i : i + self.batch_size] + batch = batch.to(self.device) + out = self.model(batch) + features.append(out) + + features = torch.cat(features, dim=0).cpu() + + if self.device == "cuda": + torch.cuda.empty_cache() + + features = features.moveaxis(1, 3) # (T, H, W, N) + features = features.reshape(ts, self.final_grid_size[0] * self.final_grid_size[1], self.hidden_state_size) # (T, grid_size**2, N) + return features + + def precompute_image_embeddings(self, images, **kwargs): + # missing = self._check_missing_embeddings() + if images.shape[-2:] != self.orig_image_size: + self.orig_image_size = images.shape[-2:] + self.final_grid_size = images.shape[-2:] + batches = self._prepare_batches(images) + self.embeddings = self._run_model(batches, **kwargs) + # self._save_features(self.embeddings) + return self.embeddings + + +@register_backbone("mouseland/cellpose-sam", 192) +class CellposeSAMFeatures(FeatureExtractor): + model_name = "mouseland/cellpose-sam" + + def __init__( + self, + image_size: tuple[int, int], + save_path: str | Path, + batch_size: int = 4, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + mode: PretrainedFeatsExtractionMode = "nearest_patch", + **kwargs, + ): + if not CELLPOSE_AVAILABLE: + raise ImportError("Cellpose is not available. 
Please install it following the instructions in the documentation.")
+        super().__init__(image_size, save_path, batch_size, device, mode)
+        self.input_size = 256
+        self.final_grid_size = 32  # 32x32 grid
+        self.n_channels = 3
+        self.hidden_state_size = 192
+        self.model = CellposeSAM()
+        self.model.to(self.device)
+        self.model.eval()
+
+        self.batch_return_type = "np.ndarray"
+        self.channel_first = False
+        self.rescale_batches = False
+
+    def normalize_array(self, images_batch):
+        batch = torch.zeros(
+            (*tuple(images_batch.shape), self.n_channels),  # add a channel dimension
+            dtype=torch.float32
+        )
+        for n, b in enumerate(images_batch):
+            if isinstance(b, torch.Tensor):
+                b = b.cpu().numpy()
+            b_ = cp_transforms.convert_image(
+                b,
+                channel_axis=None,
+                z_axis=None,
+                do_3D=False
+            )
+            b_ = cp_transforms.normalize_img(b_)
+            b_ = torch.from_numpy(b_).to(torch.float32)
+            batch[n] = b_
+        logger.debug(f"Cellpose SAM batch shape: {batch.shape}")
+        return batch[..., 0]  # keep only a single copy of the channel
+
+    def _prepare_batches(self, images):
+        if self.do_normalize:
+            images = self.normalize_array(images)
+        for i in range(0, len(images), self.batch_size):
+            end = i + self.batch_size
+            end = min(end, len(images))
+            batch = images[i:end]  # (B, H, W)
+            ts = range(i, end)
+            # if self.do_normalize:
+            #     batch = self.normalize_array(batch)
+            if len(batch.shape) == 3:
+                batch = batch.unsqueeze(1)
+                batch = batch.repeat(1, 3, 1, 1)
+            batch = batch.to(self.device)
+            yield ts, batch
+
+    def _run_model(self, images_batch: np.ndarray, **kwargs) -> torch.Tensor:
+        embeddings = []
+
+        with torch.no_grad():
+            b = F.interpolate(images_batch, size=(self.input_size[0], self.input_size[1]), mode="bilinear", align_corners=False)
+
+            x = self.model.encoder.patch_embed(b)
+            if self.model.encoder.pos_embed is not None:
+                x = x + self.model.encoder.pos_embed
+            for i, blk in enumerate(self.model.encoder.blocks):
+                x = blk(x)
+            x = self.model.encoder.neck(x.permute(0, 3, 1, 2))
+            x = self.model.out(x)  # (B, N, H, W)
+            embeddings.append(x)
+
+        embeddings = torch.cat(embeddings, dim=0)  # (T, N, H, W)
+        embeddings = embeddings.moveaxis(1, 3)  # (T, H, W, N)
+        embeddings = embeddings.reshape(-1, self.final_grid_size[0] * self.final_grid_size[1], self.hidden_state_size)  # (T, grid_size**2, N)
+        return embeddings
+
+
+@register_backbone("debug/encoded_labels", 64)
+class EncodedLabelsFeatures(FeatureExtractor):
+    """Encodes labels to 64 dimensions. 
Should work as a "perfect" feature extractor that uses GT labels as a sanity check.""" + model_name = "debug/encoded_labels" + + def __init__( + self, + image_size: tuple[int, int], + save_path: str | Path, + batch_size: int = 4, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + mode: PretrainedFeatsExtractionMode = "nearest_patch", + **kwargs, + ): + super().__init__(image_size, save_path, batch_size, device, mode) + self.input_size = 1024 + self.final_grid_size = 32 + self.n_channels = 1 + self.hidden_state_size = 64 + + def _run_model(self, images, **kwargs) -> torch.Tensor: + """Extracts embeddings from the model.""" + pass + + def precompute_image_embeddings(self, images, **kwargs): + pass + + def _encode_labels(self, labels): + """Encodes the labels to D dimensions.""" + features = np.zeros((labels.shape[0], self.hidden_state_size), dtype=np.float32) + for i in range(labels.shape[0]): + label = labels[i] + features[i] = np.array([int(x) for x in np.binary_repr(label, width=self.hidden_state_size)], dtype=np.float32) + + features = torch.from_numpy(features).to(self.device) + return features + + def compute_region_features(self, labels=None, **kwargs): + return self._encode_labels(labels) # (n_labels, self.hidden_state_size) + + +@register_backbone("debug/random", 128) +class RandomFeatures(FeatureExtractor): + model_name = "debug/random" + + def __init__( + self, + image_size: tuple[int, int], + save_path: str | Path, + batch_size: int = 4, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + mode: PretrainedFeatsExtractionMode = "nearest_patch", + seed: int = 42, + **kwargs, + ): + super().__init__(image_size, save_path, batch_size, device, mode) + self.input_size = 1024 + # self.final_grid_size = self.orig_image_size + self.final_grid_size = 128 + self.n_channels = 3 + self.hidden_state_size = 128 + + self._seed = seed if seed is not None else 42 + self.device = "cpu" + self._generator = torch.Generator(device="cpu").manual_seed(self._seed) + + self.do_save = False + + def _run_model(self, images, **kwargs) -> torch.Tensor: + """Extracts embeddings from the model.""" + # Normal distribution + # return torch.randn(len(images), self.final_grid_size**2, self.hidden_state_size, generator=self._generator).to(self.device) + # Uniform distribution + feats = torch.rand( + len(images), + self.final_grid_size[0] * self.final_grid_size[1], + self.hidden_state_size, + generator=self._generator, + dtype=torch.float32 + ) + # feats = feats * 4 - 2 # [-2, 2] + return feats.to("cpu") + + +FeatureExtractor._available_backbones = AVAILABLE_PRETRAINED_BACKBONES + + +# Embeddings post-processing + +import pickle + +from sklearn.decomposition import PCA + + +class EmbeddingsPCACompression: + + def __init__(self, original_model_name: str, n_components: int = 15, save_path: str | Path | None = None): + self.original_model_name = original_model_name + self.n_components = n_components + self.save_path = save_path / "pca_model.pkl" if save_path is not None else None + self.pca = PCA(n_components=n_components) + self.max_frames = 1500 + self.random_state = 42 + self.pca.random_state = self.random_state + self.generator = np.random.default_rng(self.random_state) + + @classmethod + def from_pretrained_cfg(cls, cfg: PretrainedFeatureExtractorConfig): + return cls( + original_model_name=cfg.model_name.replace("/", "-"), + n_components=cfg.pca_components, + save_path=cfg.pca_preprocessor_path + ) + + def fit(self, X: np.ndarray): + """Fits the PCA model to the embeddings.""" + 
self.pca.fit(X) + + if self.save_path is not None: + if not self.save_path.parents[0].exists(): + self.save_path.parents[0].mkdir(parents=True, exist_ok=True) + with open(self.save_path, 'wb') as f: + pickle.dump(self.pca, f) + + def transform(self, X: np.ndarray) -> np.ndarray: + """Transforms the embeddings using the fitted PCA model.""" + return self.pca.transform(X) + + def load_from_file(self, path: str | Path): + """Loads the PCA model from a file.""" + if isinstance(path, str): + path = Path(path) + path = path / "pca_model.pkl" if path.suffix != ".pkl" else path + with open(path, 'rb') as f: + self.pca = pickle.load(f) + logger.info(f"Loaded PCA model from {path}.") + + def fit_on_embeddings(self, embeddings_source_folders: list[str | Path]): + """Fits the PCA model to the embeddings loaded from a file.""" + embeddings = [] + N_samples = 0 + + embeddings_paths = [] + for folder in embeddings_source_folders: + for file in Path(folder).rglob("*.npy"): + if self.original_model_name in file.name: + embeddings_paths.append(file) + + if len(embeddings_paths) == 0: + return + + embeddings_paths = self.generator.permutation(embeddings_paths) + logger.info(f"Fitting PCA model on {len(embeddings_paths)} embeddings files.") + logger.info("Files :") + for p in embeddings_paths: + logger.info(f" - {p}") + logger.info("*" * 50) + + for path in embeddings_paths: + if isinstance(path, str): + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"File {path} does not exist.") + emb = np.load(path) + N_samples += emb.shape[0] + if N_samples > self.max_frames: + logger.info(f"Amount of loaded frames exceeds {self.max_frames} limit for PCA computation.") + break + else: + embeddings.append(emb) + + embeddings = np.concatenate(embeddings, axis=0) + embeddings = embeddings.reshape(-1, embeddings.shape[-1]) + self.fit(embeddings) + logger.info(f"Fitted PCA model with {self.n_components} components on {N_samples} frames.") + \ No newline at end of file diff --git a/trackastra/data/utils.py b/trackastra/data/utils.py index a451876..5a37e78 100644 --- a/trackastra/data/utils.py +++ b/trackastra/data/utils.py @@ -1,6 +1,6 @@ import logging import sys -from pathlib import Path +from pathlib import Path, WindowsPath import numpy as np import pandas as pd @@ -9,9 +9,25 @@ import tifffile from tqdm import tqdm +from trackastra.data.pretrained_features import PretrainedFeatureExtractorConfig + logger = logging.getLogger(__name__) +def make_hashable(obj): + if isinstance(obj, tuple | list): + return tuple(make_hashable(e) for e in obj) + elif isinstance(obj, Path | WindowsPath): + return obj.as_posix() + elif isinstance(obj, dict): + return tuple(sorted((k, make_hashable(v)) for k, v in obj.items())) + elif isinstance(obj, PretrainedFeatureExtractorConfig): + cfg_dict = obj.to_dict() + return make_hashable(cfg_dict) + else: + return obj + + def load_tiff_timeseries( dir: Path, dtype: str | type | None = None, diff --git a/trackastra/data/wrfeat.py b/trackastra/data/wrfeat.py index 2c72af7..3c57a5f 100644 --- a/trackastra/data/wrfeat.py +++ b/trackastra/data/wrfeat.py @@ -2,12 +2,15 @@ WindowedRegionFeatures (WRFeatures) is a class that holds regionprops features for a windowed track region. 
""" +from __future__ import annotations + import itertools import logging +from abc import ABC, abstractmethod from collections import OrderedDict from collections.abc import Iterable, Sequence from functools import reduce -from typing import Literal +from typing import TYPE_CHECKING, ClassVar, Literal import joblib import numpy as np @@ -16,9 +19,12 @@ from skimage.measure import regionprops, regionprops_table from tqdm import tqdm -from trackastra.data.utils import load_tiff_timeseries +if TYPE_CHECKING: + + from trackastra.data.pretrained_features import FeatureExtractor logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) _PROPERTIES = { "regionprops": ( @@ -34,7 +40,12 @@ "inertia_tensor", "border_dist", ), + "regionprops_small": ( + "area", + "inertia_tensor", + ), } +DEFAULT_PROPERTIES = "regionprops2" def _filter_points( @@ -67,8 +78,56 @@ def _border_dist(mask: np.ndarray, cutoff: float = 5): return tuple(r.intensity_max for r in regionprops(mask, intensity_image=dist)) +def _border_dist_fast(mask: np.ndarray, cutoff: float = 5): + cutoff = int(cutoff) + border = np.ones(mask.shape, dtype=np.float32) + ndim = len(mask.shape) + + for axis, size in enumerate(mask.shape): + # Create fade values for the band [0, cutoff) + band_vals = np.arange(cutoff, dtype=np.float32) / cutoff + + # Build slices for the low border + low_slices = [slice(None)] * ndim + low_slices[axis] = slice(0, cutoff) + border_low = border[tuple(low_slices)] + border_low_vals = np.minimum( + border_low, band_vals[(...,) + (None,) * (ndim - axis - 1)] + ) + border[tuple(low_slices)] = border_low_vals + + # Build slices for the high border + high_slices = [slice(None)] * ndim + high_slices[axis] = slice(size - cutoff, size) + band_vals_rev = band_vals[::-1] + border_high = border[tuple(high_slices)] + border_high_vals = np.minimum( + border_high, band_vals_rev[(...,) + (None,) * (ndim - axis - 1)] + ) + border[tuple(high_slices)] = border_high_vals + + dist = 1 - border + return tuple(r.intensity_max for r in regionprops(mask, intensity_image=dist)) + +# Features classes + + class WRFeatures: """regionprops features for a windowed track region.""" + PROPERTIES_DIMS: ClassVar = { + "regionprops": { + 2: 8, + 3: 12, + }, + "regionprops2": { + 2: 7, + 3: 12, + }, + "regionprops_small": { + 2: 5, + 3: 9, + }, + } def __init__( self, @@ -76,6 +135,7 @@ def __init__( labels: np.ndarray, timepoints: np.ndarray, features: OrderedDict[np.ndarray], + properties: str = DEFAULT_PROPERTIES, ): self.ndim = coords.shape[-1] if self.ndim not in (2, 3): @@ -83,9 +143,14 @@ def __init__( self.coords = coords self.labels = labels - self.features = features.copy() + if features is None: + self.features = OrderedDict() + else: + self.features = features.copy() self.timepoints = timepoints - + + self.properties = properties + def __repr__(self): s = ( f"WindowRegionFeatures(ndim={self.ndim}, nregions={len(self.labels)}," @@ -97,7 +162,25 @@ def __repr__(self): @property def features_stacked(self): - return np.concatenate([v for k, v in self.features.items()], axis=-1) + if not self.features or (len(self.features) == 1 and "pretrained_feats" in self.features): + # logger.warning("No features to stack") + return None + feats = np.concatenate( + [v for k, v in self.features.items() if k != "pretrained_feats"], + axis=-1 + ) + # raise if any NaNs in features + return feats + + @property + def pretrained_feats(self): + # for compatibility with WRPretrainedFeatures + if "pretrained_feats" in self.features: + return 
self.features["pretrained_feats"] + # return self.features["pretrained_feats"] / np.linalg.norm( + # self.features["pretrained_feats"], axis=-1, keepdims=True + # ) + return None def __len__(self): return len(self.labels) @@ -107,15 +190,22 @@ def __getitem__(self, key): return self.features[key] else: raise KeyError(f"Key {key} not found in features") + + @property + def features_dims(self): + """Returns the number of features for each property.""" + if self.properties not in self.PROPERTIES_DIMS: + raise ValueError(f"Unknown feature type {self.properties}") + return self.PROPERTIES_DIMS[self.properties][self.ndim] @classmethod - def concat(cls, feats: Sequence["WRFeatures"]) -> "WRFeatures": + def concat(cls, feats: Sequence[WRFeatures]) -> WRFeatures: """Concatenate multiple WRFeatures into a single one.""" if len(feats) == 0: raise ValueError("Cannot concatenate empty list of features") return reduce(lambda x, y: x + y, feats) - def __add__(self, other: "WRFeatures") -> "WRFeatures": + def __add__(self, other: WRFeatures) -> WRFeatures: """Concatenate two WRFeatures.""" if self.ndim != other.ndim: raise ValueError("Cannot concatenate features of different dimensions") @@ -131,23 +221,24 @@ def __add__(self, other: "WRFeatures") -> "WRFeatures": for k, v in self.features.items() ) - return WRFeatures( + return self.__class__( coords=coords, labels=labels, timepoints=timepoints, features=features ) - - @classmethod - def from_mask_img( - cls, - mask: np.ndarray, - img: np.ndarray, - properties="regionprops2", - t_start: int = 0, - ): + + @staticmethod + def get_regionprops_features(properties, mask, img, t_start=0): + """Extracts regionprops features from a mask and image.""" + img = np.asarray(img) + mask = np.asarray(mask) _ntime, ndim = mask.shape[0], mask.ndim - 1 if ndim not in (2, 3): raise ValueError("Only 2D or 3D data is supported") - properties = tuple(_PROPERTIES[properties]) + if properties is None: + properties = () + else: + properties = tuple(_PROPERTIES[properties]) + if "label" in properties or "centroid" in properties: raise ValueError( f"label and centroid should not be in properties {properties}" @@ -169,7 +260,7 @@ def from_mask_img( _df["timepoint"] = i + t_start if use_border_dist: - _df["border_dist"] = _border_dist(y) + _df["border_dist"] = _border_dist_fast(y) dfs.append(_df) df = pd.concat(dfs) @@ -180,6 +271,24 @@ def from_mask_img( timepoints = df["timepoint"].values.astype(np.int32) labels = df["label"].values.astype(np.int32) coords = df[[f"centroid-{i}" for i in range(ndim)]].values.astype(np.float32) + + # if any NaNs in features, raise + if df.isnull().values.any(): + raise ValueError("NaNs found in features DataFrame") + + return df, coords, labels, timepoints, properties + + @classmethod + def from_mask_img( + cls, + mask: np.ndarray, + img: np.ndarray, + properties: str = DEFAULT_PROPERTIES, + t_start: int = 0, + ): + df, coords, labels, timepoints, properties = cls.get_regionprops_features( + properties, mask, img, t_start=t_start + ) features = OrderedDict( ( @@ -197,26 +306,179 @@ def from_mask_img( ) return cls( - coords=coords, labels=labels, timepoints=timepoints, features=features + coords=coords, labels=labels, timepoints=timepoints, features=features, properties=properties ) -# augmentations +class WRPretrainedFeatures(WRFeatures): + """WindowedRegion with features from pre-trained models.""" + + def __init__( + self, + coords: np.ndarray, + labels: np.ndarray, + timepoints: np.ndarray, + features: OrderedDict[np.ndarray], + 
additional_properties: str | None = None + ): + super().__init__(coords, labels, timepoints, features) + self.additional_properties = additional_properties + + @property + def features_stacked(self): + if not self.features or (len(self.features) == 1 and "pretrained_feats" in self.features): + # logger.warning("No features to stack") + return None + feats = np.concatenate( + [v for k, v in self.features.items() if k != "pretrained_feats"], + axis=-1 + ) + # raise if any NaNs in features + return feats + + @property + def pretrained_feats(self): + return super().pretrained_feats + # if "pretrained_feats" in self.features: + # return self.features["pretrained_feats"] + # return None + + @classmethod + def from_mask_img( + cls, + img: np.ndarray, + mask: np.ndarray, + feature_extractor: FeatureExtractor, + t_start: int = 0, + additional_properties: str | None = None, + # embeddings: torch.Tensor | None = None, + ) -> WRPretrainedFeatures: + + ndim = img.ndim - 1 + if ndim != 2: + raise ValueError("Only 2D data is supported") + + df, coords, labels, timepoints, properties = cls.get_regionprops_features( + additional_properties, mask, img, t_start=t_start + ) + # if embeddings is None: + _, features = feature_extractor.extract_embedding(mask, timepoints, labels, coords) + # else: + # _, features = feature_extractor.extract_embedding(mask, timepoints, labels, coords, embs=embeddings) + features = features.detach().cpu().numpy() + feats_dict = OrderedDict(pretrained_feats=features) + # Add additional features similarly to WRFeatures if any + if additional_properties is not None: + for p in properties: + feats_dict[p] = np.stack( + [ + df[c].values.astype(np.float32) + for c in df.columns + if c.startswith(p) + ], + axis=-1, + ) + + return cls( + coords=coords, labels=labels, timepoints=timepoints, features=feats_dict, additional_properties=additional_properties + ) + + +class WRAugPretrainedFeatures(WRPretrainedFeatures): + + def __init__( + self, + coords: np.ndarray, + labels: np.ndarray, + timepoints: np.ndarray, + features: OrderedDict[np.ndarray], + additional_properties: str | None = None, + ): + + super().__init__(coords, labels, timepoints, features, additional_properties) + self.ndim = coords.shape[-1] + if self.ndim != 2: + raise ValueError("Only 2D data is supported") + + def __len__(self): + return super().__len__() + + @classmethod + def from_window(cls, features, coords, timepoints, labels): + """Build a WRAugPretrainedFeatures from a window. + + Args: + features (np.ndarray): The features to use. + coords (np.ndarray): The coordinates to use. + timepoints (np.ndarray): The timepoints to use. + labels (np.ndarray): The labels to use. 
+ """ + # coords = coords[:, 1:] + return cls( + coords=coords, + labels=labels, + timepoints=timepoints, + features=features, + ) + + def to_window(self): + """Convert the features to a window.""" + coords = np.concatenate((self.timepoints[:, None], self.coords), axis=-1) + + if len(self.features) == 1 and "pretrained_feats" in self.features.keys(): + feats = None + else: + feats = np.concatenate( + [v for k, v in self.features.items() if k != "pretrained_feats"], + axis=-1 + ) + pretrained_feats = self.features["pretrained_feats"] + return feats, pretrained_feats, coords, self.timepoints, self.labels + + def to_dict(self): + """Convert the features to a dictionary.""" + res = {} + for i, (t, lab) in enumerate(zip(self.timepoints, self.labels)): + t = int(t) + lab = int(lab) + if t not in res: + res[t] = {} + feat_dict = {} + for k, v in self.features.items(): + feat_dict[k] = v[i] + res[t][lab] = { + "coords": self.coords[i], + "features": dict(feat_dict), + } + return res + + +# Augmentations class WRRandomCrop: - """windowed region random crop augmentation.""" + """windowed region random crop augmentation. + + Affected properties: + - "coords" + - "labels" + - "timepoints" + - "features" + """ + return_type = WRFeatures def __init__( self, crop_size: int | tuple[int] | None = None, ndim: int = 2, + return_type=WRFeatures, ) -> None: """crop_size: tuple of int can be tuple of length 1 (all dimensions) of length ndim (y,x,...) of length 2*ndim (y1,y2, x1,x2, ...). """ + self.return_type = return_type if isinstance(crop_size, int): crop_size = (crop_size,) * 2 * ndim elif isinstance(crop_size, Iterable): @@ -258,45 +520,84 @@ def __call__(self, features: WRFeatures): ) idx = _filter_points(points, shape=crop_size, origin=corner) - + feats = OrderedDict( + (k, v[idx]) for k, v in features.features.items() + ) return ( - WRFeatures( + self.return_type( coords=points[idx], labels=features.labels[idx], timepoints=features.timepoints[idx], - features=OrderedDict((k, v[idx]) for k, v in features.features.items()), + features=feats, ), idx, ) -class WRBaseAugmentation: +class WRBaseAugmentation(ABC): + """Base class for windowed region augmentations.""" + return_type = WRFeatures + def __init__(self, p: float = 0.5) -> None: self._p = p self._rng = np.random.RandomState() def __call__(self, features: WRFeatures): + # logger.debug(f"Before augmentation: {self.__class__.__name__}, return_type={self.return_type}") if self._rng.rand() > self._p or len(features) == 0: return features - return self._augment(features) + feats = self._augment(features) + # logger.debug(f"After augmentation: {self.__class__.__name__}, return_type={self.return_type}") + self.check_features(feats) + return feats + @abstractmethod def _augment(self, features: WRFeatures): raise NotImplementedError() + + def check_features(self, features: WRFeatures): + """Check if features are valid.""" + if not isinstance(features, self.return_type): + raise ValueError(f"Expected {self.return_type}, got {type(features)}") + if len(features) == 0: + raise ValueError("Empty features") + + for k, f in features.features.items(): + if np.any(np.isnan(f)): + logger.warning(f"NaNs found in {k} after {self.__class__.__name__} augmentation") + if np.any(np.isinf(f)): + logger.warning(f"Infs found in {k} after {self.__class__.__name__} augmentation") + # if np.all(np.all(f == 0, axis=-1)): + # logger.warning(f"Empty {k} after {self.__class__.__name__} augmentation") class WRRandomFlip(WRBaseAugmentation): + """Random flip augmentation. 
+ + Affected properties: + - "area" + - "equivalent_diameter_area" + - "inertia_tensor" + """ def _augment(self, features: WRFeatures): ndim = features.ndim flip = self._rng.randint(0, 2, features.ndim) + M = np.eye(ndim) points = features.coords.copy() for i, f in enumerate(flip): if f == 1: points[:, ndim - i - 1] *= -1 - return WRFeatures( + M[i, i] = -1 + + feats = OrderedDict( + (k, _transform_affine(k, v, M)) for k, v in features.features.items() + ) + + return self.return_type( coords=points, labels=features.labels, timepoints=features.timepoints, - features=features.features, + features=feats, ) @@ -323,18 +624,22 @@ def _rotation_matrix(theta: float): def _transform_affine(k: str, v: np.ndarray, M: np.ndarray): ndim = len(M) + if k == "area": - v = np.linalg.det(M) * v + v = np.abs(np.linalg.det(M)) * v elif k == "equivalent_diameter_area": - v = np.linalg.det(M) ** (1 / len(M)) * v - + # v = np.linalg.det(M) ** (1 / len(M)) * v + v = np.abs(np.linalg.det(M)) ** (1 / len(M)) * v + # TODO check the behavior of equivalent_diameter_area in 3D regionprops elif k == "inertia_tensor": # v' = M * v * M^T - v = v.reshape(-1, ndim, ndim) + v = v.reshape(-1, ndim, ndim) + # v * M^T v = np.einsum("ijk, mk -> ijm", v, M) # M * v v = np.einsum("ij, kjm -> kim", M, v) + v = v.reshape(-1, ndim * ndim) elif k in ( "intensity_mean", @@ -344,28 +649,47 @@ def _transform_affine(k: str, v: np.ndarray, M: np.ndarray): "border_dist", ): pass + elif k == "pretrained_feats": + pass else: raise ValueError(f"Don't know how to affinely transform {k}") + + if np.isnan(v).any(): + logger.error(f"NaNs found in {k} after affine transformation") + return v class WRRandomAffine(WRBaseAugmentation): + """Random affine transformation augmentation. + + Affected properties: + - "area" + - "equivalent_diameter_area" + - "inertia_tensor" + - "coords" + """ def __init__( self, degrees: float = 10, scale: float = (0.9, 1.1), shear: float = (0.1, 0.1), p: float = 0.5, + scale_isotropic: float = (1., 1.), ): super().__init__(p) self.degrees = degrees if degrees is not None else 0 self.scale = scale if scale is not None else (1, 1) self.shear = shear if shear is not None else (0, 0) - + self.scale_isotropic = scale_isotropic + def _augment(self, features: WRFeatures): degrees = self._rng.uniform(-self.degrees, self.degrees) / 180 * np.pi + scale_iso = self._rng.uniform(*self.scale_isotropic) scale = self._rng.uniform(*self.scale, 3) + scale = scale * scale_iso + shy = self._rng.uniform(-self.shear[0], self.shear[0]) shx = self._rng.uniform(-self.shear[1], self.shear[1]) @@ -383,7 +707,7 @@ def _augment(self, features: WRFeatures): (k, _transform_affine(k, v, self._M)) for k, v in features.features.items() ) - return WRFeatures( + return self.return_type( coords=points, labels=features.labels, timepoints=features.timepoints, @@ -392,6 +716,11 @@ def _augment(self, features: WRFeatures): class WRRandomBrightness(WRBaseAugmentation): + """random brightness augmentation. + + Affected properties: + - "intensity" + """ def __init__( self, scale: tuple[float] = (0.5, 2.0), @@ -414,7 +743,7 @@ def _augment(self, features: WRFeatures): v = v * scale + shift key_vals.append((k, v)) feats = OrderedDict(key_vals) - return WRFeatures( + return self.return_type( coords=features.coords, labels=features.labels, timepoints=features.timepoints, @@ -423,6 +752,11 @@ def _augment(self, features: WRFeatures): class WRRandomOffset(WRBaseAugmentation): + """Random offset augmentation. 
+
+    Affected properties:
+    - "coords"
+    """
     def __init__(self, offset: float = (-3, 3), p: float = 0.5):
         super().__init__(p)
         self.offset = offset
@@ -431,7 +765,8 @@ def _augment(self, features: WRFeatures):
         offset = self._rng.uniform(*self.offset, features.coords.shape)
 
         coords = features.coords + offset
-        return WRFeatures(
+
+        return self.return_type(
             coords=coords,
             labels=features.labels,
             timepoints=features.timepoints,
@@ -440,7 +775,11 @@
 
 class WRRandomMovement(WRBaseAugmentation):
-    """random global linear shift."""
+    """Random global linear shift.
+
+    Affected properties:
+    - "coords"
+    """
     def __init__(self, offset: float = (-10, 10), p: float = 0.5):
         super().__init__(p)
         self.offset = offset
@@ -451,7 +790,7 @@
         offset = (features.timepoints[:, None] - tmin) * base_offset[None]
 
         coords = features.coords + offset
-        return WRFeatures(
+        return self.return_type(
             coords=coords,
             labels=features.labels,
             timepoints=features.timepoints,
@@ -460,53 +799,171 @@
 
 class WRAugmentationPipeline:
-    def __init__(self, augmentations: Sequence[WRBaseAugmentation]):
+    def __init__(self, augmentations: Sequence[WRBaseAugmentation], return_type=None):
         self.augmentations = augmentations
+        self.return_type = return_type if return_type is not None else WRFeatures
+        logger.debug(f"Augmentation pipeline return type: {self.return_type}")
+        for aug in self.augmentations:
+            aug.return_type = self.return_type
 
     def __call__(self, feats: WRFeatures):
         for aug in self.augmentations:
+            aug.return_type = self.return_type
             feats = aug(feats)
         return feats
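To make the `return_type` plumbing concrete, a small usage sketch (assuming 2D `WRFeatures` built elsewhere; the parameter values are illustrative):

```python
from trackastra.data.wrfeat import (
    WRAugmentationPipeline,
    WRFeatures,
    WRRandomFlip,
    WRRandomOffset,
)

pipeline = WRAugmentationPipeline(
    [WRRandomFlip(p=1.0), WRRandomOffset(offset=(-3, 3), p=1.0)],
    return_type=WRFeatures,  # propagated to every augmentation
)

# feats = WRFeatures.from_mask_img(mask=mask[None], img=img[None])
# feats = pipeline(feats)  # same feature type in, same type out
```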
+
+
+# Factory functions
+
+
+class AugmentationFactory:
+    default_args: ClassVar = {
+        "flip": {"p": 0.5},
+        "affine": {
+            "p": 0.8,
+            "degrees": 180,
+            "scale": (0.5, 2),
+            "shear": (0.1, 0.1),
+        },
+        "brightness": {"p": 0.8},
+        "offset": {"p": 0.8, "offset": (-3, 3)},
+        "movement": {"offset": (-10, 10), "p": 0.3},
+    }
+
+    @staticmethod
+    def create_augmentation_pipeline(augment: int, return_type: type = WRFeatures):
+        if augment == 0:
+            return None
+        elif augment == 1:
+            return WRAugmentationPipeline(
+                [
+                    WRRandomFlip(**AugmentationFactory.default_args["flip"]),
+                    WRRandomAffine(**AugmentationFactory.default_args["affine"]),
+                ],
+                return_type=return_type,
+            )
+        elif augment == 2:
+            return WRAugmentationPipeline(
+                [
+                    WRRandomFlip(**AugmentationFactory.default_args["flip"]),
+                    WRRandomAffine(**AugmentationFactory.default_args["affine"]),
+                    WRRandomBrightness(**AugmentationFactory.default_args["brightness"]),
+                    WRRandomOffset(**AugmentationFactory.default_args["offset"]),
+                ],
+                return_type=return_type,
+            )
+        elif augment == 3:
+            return WRAugmentationPipeline(
+                [
+                    WRRandomFlip(**AugmentationFactory.default_args["flip"]),
+                    WRRandomAffine(
+                        p=0.8,
+                        degrees=180,
+                        scale=(0.9, 1.1),
+                        shear=(0.1, 0.1),
+                        scale_isotropic=(0.5, 2.0),
+                    ),
+                    WRRandomBrightness(**AugmentationFactory.default_args["brightness"]),
+                    WRRandomMovement(**AugmentationFactory.default_args["movement"]),
+                    WRRandomOffset(**AugmentationFactory.default_args["offset"]),
+                ],
+                return_type=return_type,
+            )
+        elif augment == 4:
+            return WRAugmentationPipeline(
+                [
+                    WRRandomAffine(
+                        p=0.8,
+                        degrees=180,
+                        scale=(0.9, 1.1),
+                        shear=(0.1, 0.1),
+                        scale_isotropic=(0.5, 2.0),
+                    ),
+                ],
+                return_type=return_type,
+            )
+        else:
+            raise ValueError(f"Invalid augment level {augment}")
+
+    @staticmethod
+    def create_cropper(crop_size: tuple[int], ndim: int, return_type=WRFeatures):
+        if crop_size is not None:
+            return WRRandomCrop(
+                crop_size=crop_size,
+                ndim=ndim,
+                return_type=return_type,
+            )
+        return None
+
 
 def get_features(
     detections: np.ndarray,
     imgs: np.ndarray | None = None,
-    features: Literal["none", "wrfeat"] = "wrfeat",
+    features_type: Literal["none", "wrfeat", "pretrained_feats", "pretrained_feats_aug"] = "wrfeat",
     ndim: int = 2,
     n_workers=0,
     progbar_class=tqdm,
+    feature_extractor: FeatureExtractor | None = None,
 ) -> list[WRFeatures]:
+    """Extracts features from detections and images."""
     detections = _check_dimensions(detections, ndim)
     imgs = _check_dimensions(imgs, ndim)
     logger.info(f"Extracting features from {len(detections)} detections")
-    if n_workers > 0:
-        features = joblib.Parallel(n_jobs=n_workers)(
-            joblib.delayed(WRFeatures.from_mask_img)(
-                # New axis for time component
-                mask=mask[np.newaxis, ...],
-                img=img[np.newaxis, ...],
-                t_start=t,
+    if features_type in ["none", "wrfeat"]:
+        if n_workers > 0:
+            logger.info(f"Using {n_workers} processes for feature extraction")
+            features = joblib.Parallel(n_jobs=n_workers, backend="loky")(
+                joblib.delayed(WRFeatures.from_mask_img)(
+                    # New axis for time component
+                    mask=mask[np.newaxis, ...].copy(),
+                    img=img[np.newaxis, ...].copy(),
+                    t_start=t,
+                )
+                for t, (mask, img) in progbar_class(
+                    enumerate(zip(detections, imgs)),
+                    total=len(imgs),
+                    desc="Extracting features",
+                )
             )
-            for t, (mask, img) in progbar_class(
-                enumerate(zip(detections, imgs)),
-                total=len(imgs),
-                desc="Extracting features",
+        else:
+            logger.info("Using single process for feature extraction")
+            features = tuple(
+                WRFeatures.from_mask_img(
+                    mask=mask[np.newaxis, ...],
+                    img=img[np.newaxis, ...],
+                    t_start=t,
+                )
+                for t, (mask, img) in progbar_class(
+                    enumerate(zip(detections, imgs)),
+                    total=len(imgs),
+                    desc="Extracting features",
+                )
             )
-        )
+        if features_type == "none":
+            for f in features:
+                f.features = OrderedDict()
+    elif features_type in ("pretrained_feats", "pretrained_feats_aug"):
+        feature_extractor.precompute_image_embeddings(imgs)
+        features = [
+            WRPretrainedFeatures.from_mask_img(
+                img=img[np.newaxis],
+                mask=mask[np.newaxis],
+                feature_extractor=feature_extractor,
+                t_start=t,
+                additional_properties=feature_extractor.additional_features,
+            )
+            for t, (mask, img) in enumerate(zip(detections, imgs))
+        ]
    else:
-        logger.info("Using single process for feature extraction")
-        features = tuple(
-            WRFeatures.from_mask_img(
-                mask=mask[np.newaxis, ...],
-                img=img[np.newaxis, ...],
-                t_start=t,
-            )
-            for t, (mask, img) in progbar_class(
-                enumerate(zip(detections, imgs)),
-                total=len(imgs),
-                desc="Extracting features",
-            )
+        raise ValueError(
+            f"Unknown feature extraction method {features_type}. Available: 'none', 'wrfeat', 'pretrained_feats' or 'pretrained_feats_aug'."
) return features @@ -529,6 +986,7 @@ def _check_dimensions(x: np.ndarray, ndim: int): def build_windows( features: list[WRFeatures], window_size: int, progbar_class=tqdm ) -> list[dict]: + """Builds windows from a list of WRFeatures.""" windows = [] for t1, t2 in progbar_class( zip(range(0, len(features)), range(window_size, len(features) + 1)), @@ -543,13 +1001,14 @@ def build_windows( if len(feat) == 0: coords = np.zeros((0, feat.ndim), dtype=int) - + pt_feats = feat.pretrained_feats if feat.pretrained_feats is not None else None w = dict( coords=coords, t1=t1, labels=labels, timepoints=timepoints, features=feat.features_stacked, + pretrained_features=pt_feats, ) windows.append(w) @@ -558,23 +1017,32 @@ def build_windows( if __name__ == "__main__": - imgs = load_tiff_timeseries( - # "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01", - "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01", - ) - masks = load_tiff_timeseries( - # "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01_GT/TRA", - "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01_GT/TRA", - dtype=int, - ) - - features = get_features(detections=masks, imgs=imgs, ndim=3) - windows = build_windows(features, window_size=4) - - -# if __name__ == "__main__": -# y = np.zeros((1, 100, 100), np.uint8) -# y[:, 20:40, 20:60] = 1 -# x = y + np.random.normal(0, 0.1, y.shape) - -# f = WRFeatures.from_mask_img(y, x, properties=("intensity_mean", "area")) + # imgs = load_tiff_timeseries( + # # "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01", + # "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01", + # ) + # masks = load_tiff_timeseries( + # # "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01_GT/TRA", + # "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01_GT/TRA", + # dtype=int, + # ) + + # features = get_features(detections=masks, imgs=imgs, ndim=3) + # windows = build_windows(features, window_size=4) + + y = np.zeros((1, 100, 100), np.uint8) + y[:, 20:40, 20:60] = 1 + x = y + np.random.normal(0, 0.1, y.shape) + + f = WRFeatures.from_mask_img(y, x, properties='regionprops2') + + # f = WRFeatures.from_pretrained(y, x) + + augmenter = WRAugmentationPipeline([ + WRRandomAffine(degrees=10, scale=(0.9, 1.1), shear=(0.1, 0.1), p=0.5), + WRRandomBrightness(scale=(0.5, 2.0), shift=(-0.1, 0.1), p=0.5), + WRRandomOffset(offset=(-3, 3), p=0.5), + WRRandomMovement(offset=(-10, 10), p=0.5), + ]) + + f2 = augmenter(f) \ No newline at end of file diff --git a/trackastra/model/model.py b/trackastra/model/model.py index 3b2bb57..2c07b8f 100644 --- a/trackastra/model/model.py +++ b/trackastra/model/model.py @@ -32,7 +32,7 @@ def __init__( dropout=0.1, cutoff_spatial: int = 256, window: int = 16, - positional_bias: Literal["bias", "rope", "none"] = "bias", + positional_bias: Literal["bias", "rope", "none"] = "rope", positional_bias_n_spatial: int = 32, attn_dist_mode: str = "v0", ): @@ -135,137 +135,41 @@ def forward( return x -# class BidirectionalRelativePositionalAttention(RelativePositionalAttention): -# def forward( -# self, -# query1: torch.Tensor, -# query2: torch.Tensor, -# coords: torch.Tensor, -# padding_mask: torch.Tensor = None, -# ): -# B, N, D = query1.size() -# q1 = self.q_pro(query1) # (B, N, D) -# q2 = self.q_pro(query2) # (B, N, D) -# v1 = self.v_pro(query1) # (B, N, D) -# v2 = self.v_pro(query2) # (B, N, D) - -# # (B, nh, N, hs) -# q1 = q1.view(B, N, self.n_head, D // self.n_head).transpose(1, 2) -# v1 = v1.view(B, N, self.n_head, D // 
self.n_head).transpose(1, 2) -# q2 = q2.view(B, N, self.n_head, D // self.n_head).transpose(1, 2) -# v2 = v2.view(B, N, self.n_head, D // self.n_head).transpose(1, 2) - -# attn_mask = torch.zeros( -# (B, self.n_head, N, N), device=query1.device, dtype=q1.dtype -# ) - -# # add negative value but not too large to keep mixed precision loss from becoming nan -# attn_ignore_val = -1e3 - -# # spatial cutoff -# yx = coords[..., 1:] -# spatial_dist = torch.cdist(yx, yx) -# spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1) -# attn_mask.masked_fill_(spatial_mask, attn_ignore_val) - -# # dont add positional bias to self-attention if coords is None -# if coords is not None: -# if self._mode == "bias": -# attn_mask = attn_mask + self.pos_bias(coords) -# elif self._mode == "rope": -# q1, q2 = self.rot_pos_enc(q1, q2, coords) -# else: -# pass - -# dist = torch.cdist(coords, coords, p=2) -# attn_mask += torch.exp(-0.1 * dist.unsqueeze(1)) - -# # if given key_padding_mask = (B,N) then ignore those tokens (e.g. padding tokens) -# if padding_mask is not None: -# ignore_mask = torch.logical_or( -# padding_mask.unsqueeze(1), padding_mask.unsqueeze(2) -# ).unsqueeze(1) -# attn_mask.masked_fill_(ignore_mask, attn_ignore_val) - -# self.attn_mask = attn_mask.clone() - -# y1 = nn.functional.scaled_dot_product_attention( -# q1, -# q2, -# v1, -# attn_mask=attn_mask, -# dropout_p=self.dropout if self.training else 0, -# ) -# y2 = nn.functional.scaled_dot_product_attention( -# q2, -# q1, -# v2, -# attn_mask=attn_mask, -# dropout_p=self.dropout if self.training else 0, -# ) - -# y1 = y1.transpose(1, 2).contiguous().view(B, N, D) -# y1 = self.proj(y1) -# y2 = y2.transpose(1, 2).contiguous().view(B, N, D) -# y2 = self.proj(y2) -# return y1, y2 - - -# class BidirectionalCrossAttention(nn.Module): -# def __init__( -# self, -# coord_dim: int = 2, -# d_model=256, -# num_heads=4, -# dropout=0.1, -# window: int = 16, -# cutoff_spatial: int = 256, -# positional_bias: Literal["bias", "rope", "none"] = "bias", -# positional_bias_n_spatial: int = 32, -# ): -# super().__init__() -# self.positional_bias = positional_bias -# self.attn = BidirectionalRelativePositionalAttention( -# coord_dim, -# d_model, -# num_heads, -# cutoff_spatial=cutoff_spatial, -# n_spatial=positional_bias_n_spatial, -# cutoff_temporal=window, -# n_temporal=window, -# dropout=dropout, -# mode=positional_bias, -# ) - -# self.mlp = FeedForward(d_model) -# self.norm1 = nn.LayerNorm(d_model) -# self.norm2 = nn.LayerNorm(d_model) - -# def forward( -# self, -# x: torch.Tensor, -# y: torch.Tensor, -# coords: torch.Tensor, -# padding_mask: torch.Tensor = None, -# ): -# x = self.norm1(x) -# y = self.norm1(y) - -# # cross attention -# # setting coords to None disables positional bias -# x2, y2 = self.attn( -# x, -# y, -# coords=coords if self.positional_bias else None, -# padding_mask=padding_mask, -# ) -# # print(torch.norm(x2).item()/torch.norm(x).item()) -# x = x + x2 -# x = x + self.mlp(self.norm2(x)) -# y = y + y2 -# y = y + self.mlp(self.norm2(y)) - -# return x, y +class LearnedRoPERotation(nn.Module): + def __init__(self, coord_dim, feature_dim, rope_dim=None): + super().__init__() + self.rope_dim = rope_dim or feature_dim + # MLP to predict rotation angle(s) from coords + self.angle_mlp = nn.Sequential( + nn.Linear(coord_dim, 64), + nn.ReLU(), + nn.Linear(64, self.rope_dim // 2) # one angle per feature pair + ) + + def forward(self, features, coords): + """features: (B, N, D) + coords: (B, N, coord_dim). 
+ """ + B, N, D = features.shape + assert D % 2 == 0, "Feature dim must be even for RoPE." + rope_dim = self.rope_dim + + # Predict angles (B, N, rope_dim//2) + angles = self.angle_mlp(coords[..., :]) # you can select which coords to use + # Expand to (B, N, rope_dim) + cos = torch.cos(angles) + sin = torch.sin(angles) + # Prepare features for rotation + f = features[..., :rope_dim].reshape(B, N, -1, 2) + x, y = f[..., 0], f[..., 1] + # Apply rotation + x_rot = x * cos - y * sin + y_rot = x * sin + y * cos + rotated = torch.stack([x_rot, y_rot], dim=-1).reshape(B, N, rope_dim) + # Concatenate with the rest of the features if needed + if rope_dim < D: + rotated = torch.cat([rotated, features[..., rope_dim:]], dim=-1) + return rotated class TrackingTransformer(torch.nn.Module): @@ -273,6 +177,8 @@ def __init__( self, coord_dim: int = 3, feat_dim: int = 0, + pretrained_feat_dim: int = 0, + reduced_pretrained_feat_dim: int = 128, d_model: int = 128, nhead: int = 4, num_encoder_layers: int = 4, @@ -288,12 +194,16 @@ def __init__( "none", "linear", "softmax", "quiet_softmax" ] = "quiet_softmax", attn_dist_mode: str = "v0", + disable_xy_coords: bool = False, + disable_all_coords: bool = False, ): super().__init__() - + self.config = dict( coord_dim=coord_dim, feat_dim=feat_dim, + pretrained_feat_dim=pretrained_feat_dim, + reduced_pretrained_feat_dim=reduced_pretrained_feat_dim, pos_embed_per_dim=pos_embed_per_dim, d_model=d_model, nhead=nhead, @@ -307,15 +217,33 @@ def __init__( feat_embed_per_dim=feat_embed_per_dim, causal_norm=causal_norm, attn_dist_mode=attn_dist_mode, + disable_xy_coords=disable_xy_coords, + disable_all_coords=disable_all_coords, ) - - # TODO remove, alredy present in self.config - # self.window = window - # self.feat_dim = feat_dim - # self.coord_dim = coord_dim - + + # TODO temp attr, add as train config arg + if pretrained_feat_dim > 0: + self.reduced_pretrained_feat_dim = reduced_pretrained_feat_dim + else: + self.reduced_pretrained_feat_dim = 0 + self._return_norms = True + self.norms = {} + + self._disable_xy_coords = disable_xy_coords + self._disable_all_coords = disable_all_coords + + if self._disable_all_coords: + coords_proj_dims = 0 + elif self._disable_xy_coords: + coords_proj_dims = pos_embed_per_dim + else: + coords_proj_dims = (1 + coord_dim) * pos_embed_per_dim + + feats_proj_dims = feat_dim * feat_embed_per_dim + self.proj = nn.Linear( - (1 + coord_dim) * pos_embed_per_dim + feat_dim * feat_embed_per_dim, d_model + coords_proj_dims + feats_proj_dims + self.reduced_pretrained_feat_dim, + d_model ) self.norm = nn.LayerNorm(d_model) @@ -363,17 +291,36 @@ def __init__( ) else: self.feat_embed = nn.Identity() - - self.pos_embed = PositionalEncoding( - cutoffs=(window,) + (spatial_pos_cutoff,) * coord_dim, - n_pos=(pos_embed_per_dim,) * (1 + coord_dim), - ) + + if pretrained_feat_dim > 0: + self.ptfeat_proj = nn.Sequential( + nn.Linear(pretrained_feat_dim, self.reduced_pretrained_feat_dim), + ) + self.ptfeat_norm = nn.LayerNorm(self.reduced_pretrained_feat_dim) + else: + self.ptfeat_proj = nn.Identity() + self.ptfeat_norm = nn.Identity() + + if self._disable_all_coords: + self.pos_embed = nn.Identity() + + elif self._disable_xy_coords: + self.pos_embed = PositionalEncoding( + cutoffs=(window,), + n_pos=(pos_embed_per_dim,), + ) + else: + self.pos_embed = PositionalEncoding( + cutoffs=(window,) + (spatial_pos_cutoff,) * coord_dim, + n_pos=(pos_embed_per_dim,) * (1 + coord_dim), + ) # self.pos_embed = NoPositionalEncoding(d=pos_embed_per_dim * (1 + coord_dim)) - def 
forward(self, coords, features=None, padding_mask=None): + def forward(self, coords, features=None, pretrained_features=None, padding_mask=None): assert coords.ndim == 3 and coords.shape[-1] in (3, 4) _B, _N, _D = coords.shape + device = coords.device.type # disable padded coords (such that it doesnt affect minimum) if padding_mask is not None: @@ -384,15 +331,67 @@ def forward(self, coords, features=None, padding_mask=None): min_time = coords[:, :, :1].min(dim=1, keepdims=True).values coords = coords - min_time - pos = self.pos_embed(coords) - - if features is None or features.numel() == 0: - features = pos + if self._disable_xy_coords: + coords_feat = coords[:, :, :1].clone() else: - features = self.feat_embed(features) - features = torch.cat((pos, features), axis=-1) + coords_feat = coords.clone() - features = self.proj(features) + if not self._disable_all_coords: + pos = self.pos_embed(coords_feat) + else: + pos = None + + if self._return_norms: + self.norms = {} + if not self._disable_all_coords: + self.norms["pos_embed"] = pos.norm(dim=-1).detach().cpu().mean().item() + self.norms["coords"] = coords_feat.norm(dim=-1).detach().cpu().mean().item() + + with torch.amp.autocast(enabled=False, device_type=device): + # Determine if we have any features to use + has_features = features is not None and features.numel() > 0 + has_pretrained = pretrained_features is not None and pretrained_features.numel() > 0 and self.config["pretrained_feat_dim"] > 0 + + if self._return_norms: + if has_features: + self.norms["features"] = features.norm(dim=-1).detach().cpu().mean().item() + if has_pretrained: + self.norms["pretrained_features"] = pretrained_features.norm(dim=-1).detach().cpu().mean().item() + + if not has_features and not has_pretrained: + if self._disable_all_coords: + raise ValueError("features is None and all coords are disabled. 
Please enable at least one of the two.") + features_out = pos + else: + # Start with features if present, else None + features_out = self.feat_embed(features) if has_features else None + if self._return_norms and has_features: + self.norms["features_out"] = features_out.norm(dim=-1).detach().cpu().mean().item() + + # Add pretrained features if configured + if self.config["pretrained_feat_dim"] > 0 and has_pretrained: + pt_features = self.ptfeat_proj(pretrained_features) + pt_features = self.ptfeat_norm(pt_features) + if self._return_norms: + self.norms["pt_features_out"] = pt_features.norm(dim=-1).detach().cpu().mean().item() + if features_out is not None: + features_out = torch.cat((features_out, pt_features), dim=-1) + else: + features_out = pt_features + + # Add encoded coords if not disabled + if not self._disable_all_coords: + if features_out is not None: + features_out = torch.cat((pos, features_out), axis=-1) + else: + features_out = pos + + features = self.proj(features_out) + if self._return_norms: + self.norms["features_cat"] = features_out.norm(dim=-1).detach().cpu().mean().item() + self.norms["features_proj"] = features.norm(dim=-1).detach().cpu().mean().item() + # Clamp input when returning to mixed precision + features = features.clamp(torch.finfo(torch.float16).min, torch.finfo(torch.float16).max) features = self.norm(features) x = features @@ -411,7 +410,10 @@ def forward(self, coords, features=None, padding_mask=None): y = self.head_y(y) # outer product is the association matrix (logits) - A = torch.einsum("bnd,bmd->bnm", x, y) + A = torch.einsum("bnd,bmd->bnm", x, y) # /math.sqrt(_D) + + if torch.any(torch.isnan(A)): + logger.error("NaN in A") return A diff --git a/trackastra/model/model_api.py b/trackastra/model/model_api.py index ecdb552..90fddbb 100644 --- a/trackastra/model/model_api.py +++ b/trackastra/model/model_api.py @@ -3,13 +3,14 @@ from pathlib import Path from typing import Literal +import dask.array as da import numpy as np import tifffile import torch import yaml from tqdm import tqdm -from ..data import build_windows, get_features, load_tiff_timeseries +from ..data import FeatureExtractor, build_windows, get_features, load_tiff_timeseries from ..tracking import TrackGraph, build_graph, track_greedy from ..utils import normalize from .model import TrackingTransformer @@ -21,12 +22,50 @@ class Trackastra: + """A transformer-based tracking model for time-lapse data. + + Trackastra links segmented objects across time frames by predicting + associations with a transformer model trained on diverse time-lapse videos. + + The model takes as input: + - A sequence of images of shape (T,(Z),Y,X) + - Corresponding instance segmentation masks of shape (T,(Z),Y,X) + + It supports multiple tracking modes: + - greedy_nodiv: Fast greedy linking without division + - greedy: Fast greedy linking with division + - ilp: Integer Linear Programming based linking (more accurate but slower) + + Examples: + >>> # Load example data + >>> from trackastra.data import example_data_bacteria + >>> imgs, masks = example_data_bacteria() + >>> + >>> # Load pretrained model and track + >>> model = Trackastra.from_pretrained("general_2d", device="cuda") + >>> track_graph = model.track(imgs, masks, mode="greedy") + """ + def __init__( self, transformer: TrackingTransformer, train_args: dict, device: Literal["cuda", "mps", "cpu", "automatic", None] = None, ): + """Initialize Trackastra model. + + Args: + transformer: The underlying transformer model. 
+            train_args: Training configuration arguments.
+            device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
+        """
         if device == "cuda":
             if torch.cuda.is_available():
                 self.device = "cuda"
@@ -67,38 +106,76 @@ def __init__(
         self.transformer = transformer.to(self.device)
         self.train_args = train_args
+        self.imgs_path = None
+        self.masks_path = None
+        self.feature_extractor = None
 
     @classmethod
-    def from_folder(cls, dir: Path, device: str | None = None):
+    def from_folder(cls, dir: Path | str, device: str | None = None, checkpoint_path: str | None = None):
+        """Load a Trackastra model from a local folder.
+
+        Args:
+            dir: Path to model folder containing:
+                - model weights
+                - train_config.yaml with training arguments
+            device: Device to run model on.
+            checkpoint_path: Path to model checkpoint file (defaults to "model.pt" in dir).
+
+        Returns:
+            Trackastra model instance.
+        """
         # Always load to cpu first
-        transformer = TrackingTransformer.from_folder(dir, map_location="cpu")
+        if checkpoint_path is None:
+            checkpoint_path = "model.pt"
+        dir = Path(dir).expanduser()
+        transformer = TrackingTransformer.from_folder(
+            dir, map_location="cpu", checkpoint_path=checkpoint_path
+        )
         train_args = yaml.load(open(dir / "train_config.yaml"), Loader=yaml.FullLoader)
         return cls(transformer=transformer, train_args=train_args, device=device)
 
-    # TODO make safer
     @classmethod
     def from_pretrained(
         cls, name: str, device: str | None = None, download_dir: Path | None = None
     ):
+        """Load a pretrained Trackastra model.
+
+        Available pretrained models are described in detail in pretrained.json.
+
+        Args:
+            name: Name of pretrained model (e.g. "general_2d").
+            device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
+            download_dir: Directory to download model to (defaults to ~/.cache/trackastra).
+
+        Returns:
+            Trackastra model instance.
+        """
         folder = download_pretrained(name, download_dir)
         # download zip from github to location/name, then unzip
         return cls.from_folder(folder, device=device)
 
     def _predict(
         self,
-        imgs: np.ndarray,
-        masks: np.ndarray,
+        imgs: np.ndarray | da.Array,
+        masks: np.ndarray | da.Array,
         edge_threshold: float = 0.05,
         n_workers: int = 0,
+        normalize_imgs: bool = True,
         progbar_class=tqdm,
     ):
         logger.info("Predicting weights for candidate graph")
-        imgs = normalize(imgs)
+        if normalize_imgs:
+            if isinstance(imgs, da.Array):
+                imgs = imgs.map_blocks(normalize)
+            else:
+                imgs = normalize(imgs)
+
         self.transformer.eval()
 
         features = get_features(
             detections=masks,
             imgs=imgs,
+            features_type=self.train_args["features"],
+            feature_extractor=self.feature_extractor,
             ndim=self.transformer.config["coord_dim"],
             n_workers=n_workers,
             progbar_class=progbar_class,
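A brief sketch of the two loading entry points documented above (the run directory is hypothetical, and the import path assumes `Trackastra` is re-exported from `trackastra.model`):

```python
from trackastra.model import Trackastra

# From the model zoo (downloaded to ~/.cache/trackastra by default)
model = Trackastra.from_pretrained("general_2d", device="automatic")

# From a local training run, optionally pinning a specific checkpoint file
model = Trackastra.from_folder(
    "~/runs/my_model",           # hypothetical run directory
    device="cpu",
    checkpoint_path="model.pt",  # the documented default
)
```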
@@ -158,13 +235,68 @@ def _track_from_predictions(
 
     def track(
         self,
-        imgs: np.ndarray,
-        masks: np.ndarray,
+        imgs: np.ndarray | da.Array,
+        masks: np.ndarray | da.Array,
         mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
+        normalize_imgs: bool = True,
         progbar_class=tqdm,
+        n_workers: int = 0,
         **kwargs,
     ) -> TrackGraph:
-        predictions = self._predict(imgs, masks, progbar_class=progbar_class)
+        """Track objects across time frames.
+
+        This method links segmented objects across time frames using the specified
+        tracking mode. No hyperparameters need to be chosen beyond the tracking mode.
+
+        Args:
+            imgs: Input images of shape (T,(Z),Y,X) (numpy or dask array).
+            masks: Instance segmentation masks of shape (T,(Z),Y,X).
+            mode: Tracking mode:
+                - "greedy_nodiv": Fast greedy linking without division
+                - "greedy": Fast greedy linking with division
+                - "ilp": Integer Linear Programming based linking (more accurate but slower)
+            normalize_imgs: Whether to normalize the images.
+            progbar_class: Progress bar class to use.
+            n_workers: Number of worker processes for feature extraction.
+            **kwargs: Additional arguments passed to tracking algorithm.
+
+        Returns:
+            TrackGraph containing the tracking results.
+        """
+        if not imgs.shape == masks.shape:
+            raise RuntimeError(
+                f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
+            )
+
+        if not imgs.ndim == self.transformer.config["coord_dim"] + 1:
+            raise RuntimeError(
+                f"images should be a sequence of {self.transformer.config['coord_dim']}D images"
+            )
+        feat_type = self.train_args["features"]
+        if feat_type in ("pretrained_feats", "pretrained_feats_aug"):
+            additional_features = self.train_args.get(
+                "pretrained_feats_additional_props", None
+            )
+            if self.imgs_path is None:
+                save_path = "./embeddings"
+            else:
+                save_path = self.imgs_path / "embeddings"
+            self.feature_extractor = FeatureExtractor.from_model_name(
+                self.train_args["pretrained_feats_model"],
+                imgs.shape[-2:],
+                save_path=save_path,
+                mode=self.train_args["pretrained_feats_mode"],
+                device="cuda" if torch.cuda.is_available() else "cpu",
+                additional_features=additional_features,
+            )
+            self.feature_extractor.force_recompute = True
+        predictions = self._predict(
+            imgs,
+            masks,
+            normalize_imgs=normalize_imgs,
+            progbar_class=progbar_class,
+            n_workers=n_workers,
+        )
         track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs)
         return track_graph
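Since dask arrays are now accepted end to end, out-of-core inputs can be passed straight through. A small sketch, reusing `model` from the sketch above (file names and chunking are illustrative):

```python
import dask.array as da
import tifffile

imgs = da.from_array(tifffile.memmap("imgs.tif"), chunks=(1, 512, 512))
masks = da.from_array(tifffile.memmap("masks.tif"), chunks=(1, 512, 512))

# With dask inputs, normalization runs per block via map_blocks (see _predict)
track_graph = model.track(imgs, masks, mode="greedy", normalize_imgs=True)
```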
""" if not imgs_path.exists(): raise FileNotFoundError(f"{imgs_path=} does not exist.") if not masks_path.exists(): raise FileNotFoundError(f"{masks_path=} does not exist.") + self.imgs_path = imgs_path + self.masks_path = masks_path + if imgs_path.is_dir(): imgs = load_tiff_timeseries(imgs_path) else: @@ -226,4 +369,6 @@ def track_from_disk( f"Img shape {imgs.shape} and mask shape {masks. shape} do not match." ) - return self.track(imgs, masks, mode, **kwargs), masks + return self.track( + imgs, masks, mode, normalize_imgs=normalize_imgs, **kwargs + ), masks diff --git a/trackastra/model/model_parts.py b/trackastra/model/model_parts.py index 5f6e042..b25c8a1 100644 --- a/trackastra/model/model_parts.py +++ b/trackastra/model/model_parts.py @@ -8,6 +8,17 @@ import torch.nn.functional as F from torch import nn +try: + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func + _FLASH_ATTN = True and torch.cuda.is_available() +except ImportError: + flash_attn_varlen_qkvpacked_func = None + _FLASH_ATTN = False + +# if not _FLASH_ATTN: +# warnings.warn("flash_attn not found or not available for device, falling back to normal attention.") +# warnings.warn("Install with\n\npip install flash-attn --no-build-isolation\n\n") + from .rope import RotaryPositionalEncoding logger = logging.getLogger(__name__) @@ -58,7 +69,7 @@ def __init__( def forward(self, coords: torch.Tensor): _B, _N, D = coords.shape - assert D == len(self.freqs) + assert D == len(self.freqs), f"coords dim {D} must be equal to number of frequencies {len(self.freqs)}" embed = torch.cat( tuple( torch.cat( @@ -169,6 +180,13 @@ def __init__( mode: Literal["bias", "rope", "none"] = "bias", attn_dist_mode: str = 'v0' ): + """ + + attn_dist_mode: str + v0: exponential decay + v1: exponential decay with cutoff_spatial + v2: no masking (except padding_mask). 
+ """ super().__init__() if not embed_dim % (2 * n_head) == 0: @@ -220,13 +238,13 @@ def __init__( self._mode = mode def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - coords: torch.Tensor, - padding_mask: torch.Tensor = None, - ): + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + coords: torch.Tensor, + padding_mask: torch.Tensor = None, + ): B, N, D = query.size() q = self.q_pro(query) # (B, N, D) k = self.k_pro(key) # (B, N, D) @@ -243,11 +261,12 @@ def forward( # add negative value but not too large to keep mixed precision loss from becoming nan attn_ignore_val = -1e3 - # spatial cutoff - yx = coords[..., 1:] - spatial_dist = torch.cdist(yx, yx) - spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1) - attn_mask.masked_fill_(spatial_mask, attn_ignore_val) + if self.attn_dist_mode != 'v2': + # spatial cutoff + yx = coords[..., 1:] + spatial_dist = torch.cdist(yx, yx) + spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1) + attn_mask.masked_fill_(spatial_mask, attn_ignore_val) # dont add positional bias to self-attention if coords is None if coords is not None: @@ -263,6 +282,8 @@ def forward( attn_mask += torch.exp(-0.1 * dist.unsqueeze(1)) elif self.attn_dist_mode == 'v1': attn_mask += torch.exp(-5 * spatial_dist.unsqueeze(1) / self.cutoff_spatial) + elif self.attn_dist_mode == 'v2': + pass else: raise ValueError(f"Unknown attn_dist_mode {self.attn_dist_mode}") @@ -271,15 +292,200 @@ def forward( ignore_mask = torch.logical_or( padding_mask.unsqueeze(1), padding_mask.unsqueeze(2) ).unsqueeze(1) - attn_mask.masked_fill_(ignore_mask, attn_ignore_val) + if self.attn_dist_mode == 'v2': + attn_mask = ~ignore_mask + else: + attn_mask.masked_fill_(ignore_mask, attn_ignore_val) - # self.attn_mask = attn_mask.clone() + self.attn_mask = attn_mask.clone() - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0 - ) + if _FLASH_ATTN and self.attn_dist_mode == 'v2' and False: # Disable for now + y = compute_attention_with_unpadding(q, k, v, padding_mask) + else: + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0 + ) y = y.transpose(1, 2).contiguous().view(B, N, D) # output projection y = self.proj(y) return y + + +def compute_attention_with_unpadding(q, k, v, padding_mask): + """Compute self-attention using flash_attn_varlen_qkvpacked_func with unpadding. + + Args: + q, k, v: Tensors of shape (B, H, N, D) + padding_mask: Tensor of shape (B, N), where True means the element should be ignored. + + Returns: + output: Tensor of shape (B, H, N, D), the result of the attention computation. 
+ """ + B, H, N, D = q.shape + assert q.shape == k.shape == v.shape + assert padding_mask.shape == (B, N) + + # Extract sequence lengths and create cumulative sequence lengths + valid_tokens_mask = ~padding_mask # Flip the mask so True means valid + lens = valid_tokens_mask.sum(dim=-1).tolist() # Length of each sequence + cu_seqlens = torch.tensor([0, *torch.cumsum(torch.tensor(lens), dim=0).tolist()], dtype=torch.int32, device=q.device) + + # Unpad Q, K, V + q_unpadded = q.transpose(1, 2)[valid_tokens_mask] # Shape: (total_tokens, H, D) + k_unpadded = k.transpose(1, 2)[valid_tokens_mask] # Shape: (total_tokens, H, D) + v_unpadded = v.transpose(1, 2)[valid_tokens_mask] # Shape: (total_tokens, H, D) + + # Stack Q, K, V into a single tensor for FlashAttention + qkv_unpadded = torch.stack([q_unpadded, k_unpadded, v_unpadded], dim=1) # Shape: (total_tokens, 3, H, D) + + qkv_unpadded = qkv_unpadded.bfloat16() + # FlashAttention + max_seqlen = max(lens) # Maximum sequence length in the batch + output_unpadded = flash_attn_varlen_qkvpacked_func( + qkv_unpadded, # (total_tokens, 3, H, D) + cu_seqlens, # (B + 1,) + max_seqlen=max_seqlen, + dropout_p=0.0, # Set to 0.0 for evaluation + causal=False, + ) # Output: (total_tokens, H, D) + + output_unpadded = output_unpadded.to(q.dtype) + # Re-pad to original dimensions + output_padded = torch.zeros((B, N, H, D), dtype=output_unpadded.dtype, device=output_unpadded.device) + output_padded[valid_tokens_mask] = output_unpadded + output_padded = output_padded.transpose(1, 2) # Shape: (B, H, N, D) + + return output_padded + +# class BidirectionalRelativePositionalAttention(RelativePositionalAttention): +# def forward( +# self, +# query1: torch.Tensor, +# query2: torch.Tensor, +# coords: torch.Tensor, +# padding_mask: torch.Tensor = None, +# ): +# B, N, D = query1.size() +# q1 = self.q_pro(query1) # (B, N, D) +# q2 = self.q_pro(query2) # (B, N, D) +# v1 = self.v_pro(query1) # (B, N, D) +# v2 = self.v_pro(query2) # (B, N, D) + +# # (B, nh, N, hs) +# q1 = q1.view(B, N, self.n_head, D // self.n_head).transpose(1, 2) +# v1 = v1.view(B, N, self.n_head, D // self.n_head).transpose(1, 2) +# q2 = q2.view(B, N, self.n_head, D // self.n_head).transpose(1, 2) +# v2 = v2.view(B, N, self.n_head, D // self.n_head).transpose(1, 2) + +# attn_mask = torch.zeros( +# (B, self.n_head, N, N), device=query1.device, dtype=q1.dtype +# ) + +# # add negative value but not too large to keep mixed precision loss from becoming nan +# attn_ignore_val = -1e3 + +# # spatial cutoff +# yx = coords[..., 1:] +# spatial_dist = torch.cdist(yx, yx) +# spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1) +# attn_mask.masked_fill_(spatial_mask, attn_ignore_val) + +# # dont add positional bias to self-attention if coords is None +# if coords is not None: +# if self._mode == "bias": +# attn_mask = attn_mask + self.pos_bias(coords) +# elif self._mode == "rope": +# q1, q2 = self.rot_pos_enc(q1, q2, coords) +# else: +# pass + +# dist = torch.cdist(coords, coords, p=2) +# attn_mask += torch.exp(-0.1 * dist.unsqueeze(1)) + +# # if given key_padding_mask = (B,N) then ignore those tokens (e.g. 
+#         if padding_mask is not None:
+#             ignore_mask = torch.logical_or(
+#                 padding_mask.unsqueeze(1), padding_mask.unsqueeze(2)
+#             ).unsqueeze(1)
+#             attn_mask.masked_fill_(ignore_mask, attn_ignore_val)
+
+#         self.attn_mask = attn_mask.clone()
+
+#         y1 = nn.functional.scaled_dot_product_attention(
+#             q1,
+#             q2,
+#             v1,
+#             attn_mask=attn_mask,
+#             dropout_p=self.dropout if self.training else 0,
+#         )
+#         y2 = nn.functional.scaled_dot_product_attention(
+#             q2,
+#             q1,
+#             v2,
+#             attn_mask=attn_mask,
+#             dropout_p=self.dropout if self.training else 0,
+#         )
+
+#         y1 = y1.transpose(1, 2).contiguous().view(B, N, D)
+#         y1 = self.proj(y1)
+#         y2 = y2.transpose(1, 2).contiguous().view(B, N, D)
+#         y2 = self.proj(y2)
+#         return y1, y2
+
+
+# class BidirectionalCrossAttention(nn.Module):
+#     def __init__(
+#         self,
+#         coord_dim: int = 2,
+#         d_model=256,
+#         num_heads=4,
+#         dropout=0.1,
+#         window: int = 16,
+#         cutoff_spatial: int = 256,
+#         positional_bias: Literal["bias", "rope", "none"] = "bias",
+#         positional_bias_n_spatial: int = 32,
+#     ):
+#         super().__init__()
+#         self.positional_bias = positional_bias
+#         self.attn = BidirectionalRelativePositionalAttention(
+#             coord_dim,
+#             d_model,
+#             num_heads,
+#             cutoff_spatial=cutoff_spatial,
+#             n_spatial=positional_bias_n_spatial,
+#             cutoff_temporal=window,
+#             n_temporal=window,
+#             dropout=dropout,
+#             mode=positional_bias,
+#         )
+
+#         self.mlp = FeedForward(d_model)
+#         self.norm1 = nn.LayerNorm(d_model)
+#         self.norm2 = nn.LayerNorm(d_model)
+
+#     def forward(
+#         self,
+#         x: torch.Tensor,
+#         y: torch.Tensor,
+#         coords: torch.Tensor,
+#         padding_mask: torch.Tensor = None,
+#     ):
+#         x = self.norm1(x)
+#         y = self.norm1(y)
+
+#         # cross attention
+#         # setting coords to None disables positional bias
+#         x2, y2 = self.attn(
+#             x,
+#             y,
+#             coords=coords if self.positional_bias else None,
+#             padding_mask=padding_mask,
+#         )
+#         # print(torch.norm(x2).item()/torch.norm(x).item())
+#         x = x + x2
+#         x = x + self.mlp(self.norm2(x))
+#         y = y + y2
+#         y = y + self.mlp(self.norm2(y))
+
+#         return x, y
\ No newline at end of file
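The variable-length attention path added above packs only the valid (non-padded) tokens before handing them to FlashAttention, then scatters the results back. A minimal sketch of that unpad/re-pad bookkeeping in plain `torch`, so it runs without the optional `flash_attn` dependency; the shapes and mask convention mirror `compute_attention_with_unpadding`, and all concrete sizes below are made up:

```python
import torch

B, H, N, D = 2, 2, 5, 4
q = torch.randn(B, H, N, D)
# True marks padding tokens that should be ignored (same convention as above)
padding_mask = torch.tensor([
    [False, False, False, True, True],   # sequence of length 3
    [False, False, False, False, True],  # sequence of length 4
])

valid = ~padding_mask                    # True = real token
lens = valid.sum(dim=-1).tolist()        # [3, 4]
cu_seqlens = torch.tensor(
    [0, *torch.cumsum(torch.tensor(lens), dim=0).tolist()], dtype=torch.int32
)                                        # [0, 3, 7]: sequence boundaries in the packed layout

# (B, H, N, D) -> (B, N, H, D) -> (total_tokens, H, D): all padding rows are gone
q_unpadded = q.transpose(1, 2)[valid]
assert q_unpadded.shape == (sum(lens), H, D)

# Re-padding scatters the packed tokens back into their original slots,
# leaving zeros at the padded positions
out = torch.zeros(B, N, H, D)
out[valid] = q_unpadded
assert torch.equal(out.transpose(1, 2)[:, :, :3], q[:, :, :3])
```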
""" - feats = torch.from_numpy(batch["features"]) + if batch["features"] is not None: + feats = torch.from_numpy(batch["features"]) + else: + feats = None + if batch["pretrained_features"] is not None: + pretrained_feats = torch.from_numpy(batch["pretrained_features"]) + else: + pretrained_feats = None coords = torch.from_numpy(batch["coords"]) timepoints = torch.from_numpy(batch["timepoints"]).long() # Hack that assumes that all parameters of a model are on the same device device = next(model.parameters()).device - feats = feats.unsqueeze(0).to(device) - timepoints = timepoints.unsqueeze(0).to(device) coords = coords.unsqueeze(0).to(device) + timepoints = timepoints.unsqueeze(0).to(device) + if feats is not None: + feats = feats.unsqueeze(0).to(device) + if pretrained_feats is not None: + pretrained_feats = pretrained_feats.unsqueeze(0).to(device) # Concat timepoints to coordinates coords = torch.cat((timepoints.unsqueeze(2).float(), coords), dim=2) with torch.no_grad(): - A = model(coords, features=feats) + A = model(coords, features=feats, pretrained_features=pretrained_feats) A = model.normalize_output(A, timepoints, coords) # # Spatially far entries should not influence the causal normalization @@ -60,20 +75,39 @@ def predict_windows( edge_threshold: float = 0.05, spatial_dim: int = 3, progbar_class=tqdm, + pred_func_override=None, ) -> dict: - """_summary_. - + """Predict associations between objects across sliding windows. + + This function processes a sequence of sliding windows to predict associations + between objects across time frames. It handles: + - Object tracking across time + - Weight normalization across windows + - Edge thresholding + - Time-based filtering + Args: - windows (_type_): _description_ - features (_type_): _description_ - model (_type_): _description_ - intra_window_weight (_type_, optional): _description_. Defaults to 0. - delta_t (_type_, optional): _description_. Defaults to 1. - edge_threshold (_type_, optional): _description_. Defaults to 0.05. - spatial_dim: Dimensionality of the input masks. This might be < model.coord_dim + windows: List of window dictionaries containing: + - timepoints: Array of time points + - labels: Array of object labels + - features: Object features + - coords: Object coordinates + features: List of feature objects containing: + - labels: Object labels + - timepoints: Time points + - coords: Object coordinates + model: TrackingTransformer model to use for prediction. + intra_window_weight: Weight factor for objects in middle of window. Defaults to 0. + delta_t: Maximum time difference between objects to consider. Defaults to 1. + edge_threshold: Minimum association score to consider. Defaults to 0.05. + spatial_dim: Dimensionality of input masks. May be less than model.coord_dim. + progbar_class: Progress bar class to use. Defaults to tqdm. + pred_func_override: Function to override the prediction function. This is useful for debugging or testing other prediction methods. Returns: - _type_: _description_ + Dictionary containing: + - nodes: List of node properties (id, coords, time, label) + - weights: Tuple of ((node_i, node_j), weight) pairs """ # first get all objects/coords time_labels_to_id = dict() @@ -109,10 +143,17 @@ def predict_windows( # This assumes that the samples in the dataset are ordered by time and start at 0. 
        batch = windows[t]
         timepoints = batch["timepoints"]
+        if isinstance(timepoints, torch.Tensor):
+            timepoints = timepoints.cpu().numpy()
         labels = batch["labels"]
-
-        A = predict(batch, model)
-
+        if isinstance(labels, torch.Tensor):
+            labels = labels.cpu().numpy()
+
+        if pred_func_override is None:
+            A = predict(batch, model)
+        else:
+            A = pred_func_override(batch)
+
         dt = timepoints[None, :] - timepoints[:, None]
         time_mask = np.logical_and(dt <= delta_t, dt > 0)
         A[~time_mask] = 0
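Since `pred_func_override` is new in this diff: anything that accepts a window `batch` and returns an association matrix of shape (N, N) will do. The constant-score function below is a hypothetical stand-in for illustration, not part of the library:

```python
import numpy as np

def constant_scores(batch):
    """Hypothetical override: assign every object pair the same association score."""
    n = len(batch["timepoints"])
    return np.full((n, n), 0.5, dtype=np.float32)

batch = {"timepoints": np.array([0, 0, 1, 1])}
print(constant_scores(batch).shape)  # (4, 4)

# Plugged into the pipeline (windows/features/model as set up elsewhere):
# predict_windows(windows, features, model, pred_func_override=constant_scores)
```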
diff --git a/trackastra/model/pretrained.json b/trackastra/model/pretrained.json
index 0d7e45d..a00137d 100644
--- a/trackastra/model/pretrained.json
+++ b/trackastra/model/pretrained.json
@@ -1,9 +1,9 @@
 {
     "general_2d": {
-        "tags": ["cells, nuclei, bacteria, epithelial"],
+        "tags": ["cells, nuclei, bacteria, epithelial, yeast, particles"],
         "dimensionality": [2],
-        "description": "For tracking fluorescent nuclei, bacteria (PhC), whole cells (BF, PhC, DIC), epithelial cells with fluorescent membrane.",
-        "url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.1.1/general_2d.zip",
+        "description": "For tracking fluorescent nuclei, bacteria (PhC), whole cells (BF, PhC, DIC), epithelial cells with fluorescent membrane, budding yeast cells (PhC), and fluorescent particles.",
+        "url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/general_2d.zip",
         "datasets": {
             "Subset of Cell Tracking Challenge 2d datasets": {
                 "url": "https://celltrackingchallenge.net/2d-datasets/",
@@ -17,6 +17,10 @@
                 "url": "https://zenodo.org/records/7260137",
                 "reference": "Seiffarth J, Scherr T, Wollenhaupt B, Neumann O, Scharr H, Kohlheyer D, Mikut R, Nöh K. ObiWan-Microbi: OMERO-based integrated workflow for annotating microbes in the cloud. SoftwareX. 2024 May 1;26:101638."
             },
+            "Bacteria Persat": {
+                "url": "https://www.p-lab.science",
+                "reference": "Datasets kindly provided by the Persat lab, EPFL."
+            },
             "DeepCell": {
                 "url": "https://datasets.deepcell.org/data",
                 "reference": "Schwartz, M, Moen E, Miller G, Dougherty T, Borba E, Ding R, Graf W, Pao E, Van Valen D. Caliban: Accurate cell tracking and lineage construction in live-cell imaging experiments with deep learning. bioRxiv. 2023 Sept 13:803205."
@@ -36,14 +40,37 @@
             },
             "Synthetic nuclei": {
                 "reference": "Weigert group live cell simulator."
+            },
+            "Synthetic particles": {
+                "reference": "Weigert group particle simulator."
+            },
+            "Particle Tracking Challenge": {
+                "url": "http://bioimageanalysis.org/track/#data",
+                "reference": "Chenouard, N., Smal, I., De Chaumont, F., Maška, M., Sbalzarini, I. F., Gong, Y., ... & Meijering, E. (2014). Objective comparison of particle tracking methods. Nature Methods, 11(3), 281-289."
+            },
+            "Yeast Cell-ACDC": {
+                "url": "https://zenodo.org/records/6795124",
+                "reference": "Padovani, F., Mairhörmann, B., Falter-Braun, P., Lengefeld, J., & Schmoller, K. M. (2022). Segmentation, tracking and cell cycle analysis of live-cell imaging data with Cell-ACDC. BMC Biology, 20(1), 174."
+            },
+            "DeepSea": {
+                "url": "https://deepseas.org/datasets/",
+                "reference": "Zargari, A., Lodewijk, G. A., Mashhadi, N., Cook, N., Neudorf, C. W., Araghbidikashani, K., ... & Shariati, S. A. (2023). DeepSea is an efficient deep-learning model for single-cell segmentation and tracking in time-lapse microscopy. Cell Reports Methods, 3(6)."
+            },
+            "Btrack": {
+                "url": "https://rdr.ucl.ac.uk/articles/dataset/Cell_tracking_reference_dataset/16595978",
+                "reference": "Ulicna, K., Vallardi, G., Charras, G., & Lowe, A. R. (2021). Automated deep lineage tree analysis using a Bayesian single cell tracking approach. Frontiers in Computer Science, 3, 734559."
+            },
+            "E. coli in mother machine": {
+                "url": "https://zenodo.org/records/11237127",
+                "reference": "O'Connor, O. M., & Dunlop, M. J. (2024). Cell-TRACTR: A transformer-based model for end-to-end segmentation and tracking of cells. bioRxiv, 2024-07."
+            }
         }
     },
     "ctc": {
-        "tags": ["ctc", "isbi2024"],
+        "tags": ["ctc", "Cell Tracking Challenge", "Cell Linking Benchmark"],
         "dimensionality": [2, 3],
-        "description": "For tracking Cell Tracking Challenge datasets. Winner of the ISBI 2024 CTC generalizable linking challenge.",
-        "url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.1/ctc.zip",
+        "description": "For tracking Cell Tracking Challenge datasets. This is the successor of the winning model of the ISBI 2024 CTC generalizable linking challenge.",
+        "url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/ctc.zip",
         "datasets": {
             "All Cell Tracking Challenge 2d+3d datasets with available GT and ERR_SEG": {
                 "url": "https://celltrackingchallenge.net/3d-datasets/",
diff --git a/trackastra/model/pretrained.py b/trackastra/model/pretrained.py
index fcb86e9..3a3c047 100644
--- a/trackastra/model/pretrained.py
+++ b/trackastra/model/pretrained.py
@@ -2,6 +2,7 @@
 import shutil
 import tempfile
 import zipfile
+from importlib.resources import files
 from pathlib import Path
 
 import requests
@@ -10,10 +11,8 @@
 logger = logging.getLogger(__name__)
 
 _MODELS = {
-    "ctc": (
-        "https://github.com/weigertlab/trackastra-models/releases/download/v0.1/ctc.zip"
-    ),
-    "general_2d": "https://github.com/weigertlab/trackastra-models/releases/download/v0.1.1/general_2d.zip",
+    "ctc": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/ctc.zip",
+    "general_2d": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/general_2d.zip",
 }
 
 
@@ -57,7 +56,7 @@ def download(url: str, fname: Path):
 def download_pretrained(name: str, download_dir: Path | None = None):
     # TODO make safe, introduce versioning
     if download_dir is None:
-        download_dir = Path("~/.trackastra/.models").expanduser()
+        download_dir = files("trackastra").joinpath(".models")
     else:
         download_dir = Path(download_dir)
diff --git a/trackastra/model/rope.py b/trackastra/model/rope.py
index e1f027a..5beef41 100644
--- a/trackastra/model/rope.py
+++ b/trackastra/model/rope.py
@@ -75,6 +75,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor):
         co = co.unsqueeze(1).repeat_interleave(2, dim=-1)
         si = si.unsqueeze(1).repeat_interleave(2, dim=-1)
 
+
         q2 = q * co + _rotate_half(q) * si
         k2 = k * co + _rotate_half(k) * si
 
@@ -84,12 +85,14 @@
 if __name__ == "__main__":
     model = RotaryPositionalEncoding((256, 256), (32, 32))
 
-    x = 100 * torch.rand(1, 17, 2)
-    q = torch.rand(1, 4, 17, 64)
-    k = torch.rand(1, 4, 17, 64)
+    x = 100 * torch.randn(1, 17, 2)
+    q = torch.randn(1, 4, 17, 64)
+    k = torch.randn(1, 4, 17, 64)
 
     q1, k1 = model(q, k, x)
     A1 = q1[:, :, 0] @ k1[:, :, 0].transpose(-1, -2)
 
     q2, k2 = model(q, k, x + 10)
     A2 = q2[:, :, 0] @ k2[:, :, 0].transpose(-1, -2)
+
+    print("close", torch.allclose(A1, A2))
\ No newline at end of file
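The extended `__main__` block above prints whether the attention logits are unchanged when all coordinates are shifted by a constant. That is the point of rotary encoding: queries and keys are rotated by the same position-dependent angles, so their inner products depend only on relative positions. The same check as a standalone script with an explicit tolerance (the import path follows this repo's layout; exact equality does not hold in float32, hence `atol`):

```python
import torch
from trackastra.model.rope import RotaryPositionalEncoding

model = RotaryPositionalEncoding((256, 256), (32, 32))

x = 100 * torch.randn(1, 17, 2)  # (B, N, 2) spatial coordinates
q = torch.randn(1, 4, 17, 64)    # (B, heads, N, head_dim)
k = torch.randn(1, 4, 17, 64)

q1, k1 = model(q, k, x)
q2, k2 = model(q, k, x + 10)     # same points, globally shifted

A1 = q1[:, :, 0] @ k1[:, :, 0].transpose(-1, -2)
A2 = q2[:, :, 0] @ k2[:, :, 0].transpose(-1, -2)
assert torch.allclose(A1, A2, atol=1e-3), "RoPE logits should be shift-invariant"
```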
diff --git a/trackastra/utils/__init__.py b/trackastra/utils/__init__.py
index b9f3239..e1fb217 100644
--- a/trackastra/utils/__init__.py
+++ b/trackastra/utils/__init__.py
@@ -1,8 +1,10 @@
 # ruff: noqa: F401
 from .utils import (
+    add_timepoints_to_coords,
     blockwise_causal_norm,
     blockwise_sum,
+    masks2properties,
     normalize,
     preallocate_memory,
     random_label_cmap,
diff --git a/trackastra/utils/utils.py b/trackastra/utils/utils.py
index 730268c..7cfc767 100644
--- a/trackastra/utils/utils.py
+++ b/trackastra/utils/utils.py
@@ -6,9 +6,12 @@
 from pathlib import Path
 from timeit import default_timer
 
+import dask.array as da
 import matplotlib
 import numpy as np
 import torch
+from skimage.measure import regionprops
+from tqdm import tqdm
 
 logger = logging.getLogger(__name__)
 
@@ -159,6 +162,73 @@ def _blockwise_sum_with_bounds(A: torch.Tensor, bounds: torch.Tensor, dim: int = 0):
     return B
 
 
+def add_timepoints_to_coords(coords, timepoints):
+    """Adds timepoints as the first dimension to the coordinates.
+
+    Args:
+        coords (np.ndarray): Coordinates of shape (N, 2) or (N, 3).
+        timepoints (np.ndarray): Timepoints of shape (N,).
+
+    Returns:
+        np.ndarray: Coordinates with timepoints prepended as the first column, of shape (N, 3) or (N, 4).
+    """
+    if coords.ndim != 2 or coords.shape[1] not in (2, 3):
+        raise ValueError("coords must be a 2D array with shape (N, 2) or (N, 3).")
+    if timepoints.ndim != 1 or timepoints.shape[0] != coords.shape[0]:
+        raise ValueError("timepoints must be a 1D array with the same length as the first dimension of coords.")
+
+    return np.column_stack((timepoints, coords))
+
+
+def masks2properties(imgs, masks, return_props_by_time=False):
+    """Turn label masks into lists of properties, sorted (ascending) by time and label id.
+
+    Args:
+        imgs (np.ndarray): Images of shape T, (Z), H, W. Only used for a length check against masks.
+        masks (np.ndarray): Label masks of shape T, (Z), H, W.
+        return_props_by_time (bool): If True, additionally return properties by time
+            (dict with keys 'coords' and 'labels' for each timepoint).
+
+    Returns:
+        labels (np.ndarray): Label id of each detection.
+        ts (np.ndarray): Timepoint of each detection.
+        coords (np.ndarray): Centroid coordinates of each detection.
+        properties_by_time (dict): Only returned if return_props_by_time is True.
+    """
+    # Get coordinates, timepoints, and labels of detections
+    labels = []
+    ts = []
+    coords = []
+    properties_by_time = dict()
+    assert len(imgs) == len(masks)
+    for _t, frame in tqdm(
+        enumerate(masks),
+        total=len(masks),
+        leave=False,
+        desc="Loading masks and properties",
+    ):
+        regions = regionprops(frame)
+        t_labels = []
+        t_ts = []
+        t_coords = []
+        for _r in regions:
+            t_labels.append(_r.label)
+            t_ts.append(_t)
+            centroid = np.array(_r.centroid).astype(np.float16)
+            t_coords.append(centroid)
+
+        properties_by_time[_t] = dict(coords=t_coords, labels=t_labels)
+        labels.extend(t_labels)
+        ts.extend(t_ts)
+        coords.extend(t_coords)
+
+    labels = np.array(labels, dtype=int)
+    ts = np.array(ts, dtype=int)
+    coords = np.array(coords, dtype=np.float16)
+    if return_props_by_time:
+        return labels, ts, coords, properties_by_time
+    return labels, ts, coords
+
+
 def _bounds_from_timepoints(timepoints: torch.Tensor):
     assert timepoints.ndim == 1
     bounds = torch.cat(
@@ -263,7 +333,8 @@ def blockwise_causal_norm(
         A = torch.sigmoid(A)
     if mask_invalid is not None:
         assert mask_invalid.shape == A.shape
-        A[mask_invalid] = 0
+        # A[mask_invalid] = 0
+        A = A.masked_fill(mask_invalid, 0)
 
     u0, u1 = A, A
     ma0 = ma1 = 0
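To see how the two new helpers compose, a small sketch on a toy two-frame movie (synthetic arrays, just for shape checking):

```python
import numpy as np
from trackastra.utils import add_timepoints_to_coords, masks2properties

# Toy movie: two frames with one labelled square each, shape (T, H, W)
masks = np.zeros((2, 64, 64), dtype=np.int32)
masks[0, 10:20, 10:20] = 1
masks[1, 12:22, 12:22] = 1
imgs = np.random.rand(*masks.shape).astype(np.float32)

labels, ts, coords = masks2properties(imgs, masks)
print(labels, ts)  # [1 1] [0 1]

# Prepend the timepoint to each centroid -> (N, 3) rows of (t, y, x)
tcoords = add_timepoints_to_coords(coords, ts)
print(tcoords.shape)  # (2, 3)
```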
@@ -304,9 +375,22 @@ def normalize_tensor(x: torch.Tensor, dim: int | None = None, eps: float = 1e-8):
     return (x - mi) / (ma - mi + eps)
 
 
-def normalize(x: np.ndarray):
-    mi, ma = np.percentile(x, (1, 99.8)).astype(np.float32)
-    return (x - mi) / (ma - mi + 1e-8)
+def normalize(x: np.ndarray | da.Array, subsample: int | None = 4):
+    """Percentile-normalize the image.
+
+    If subsample is not None, the percentiles are computed on a version of the image
+    that is subsampled along the last two axes, which is much faster for large images.
+    """
+    x = x.astype(np.float32)
+    if subsample is not None and all(s > 64 * subsample for s in x.shape[-2:]):
+        y = x[..., ::subsample, ::subsample]
+    else:
+        y = x
+
+    mi, ma = np.percentile(y, (1, 99.8)).astype(np.float32)
+    x -= mi
+    x /= ma - mi + 1e-8
+    return x
 
 
 def batched(x, batch_size, device):
@@ -460,3 +544,11 @@ def str2path(x: str) -> Path:
 
     tps = torch.repeat_interleave(torch.arange(5), 10)
     C = blockwise_causal_norm(A, tps)
+
+
+def percentile_norm(b):
+    """Percentile-normalize each frame of a batch in place and clip to [0, 1]."""
+    for i, im in enumerate(b):
+        p1, p99 = np.percentile(im, (1, 99.8))
+        b[i] = (im - p1) / (p99 - p1)
+        b[i] = np.clip(b[i], 0, 1)
+    return b
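Finally, a usage note on the reworked `normalize`: when the last two axes are large enough (more than 64 times the subsampling factor), the 1/99.8 percentiles are estimated from every `subsample`-th pixel instead of the full array, which barely moves the normalization but avoids a full pass over the data. A quick sketch:

```python
import numpy as np
from trackastra.utils import normalize

movie = np.random.rand(10, 1024, 1024).astype(np.float32)

fast = normalize(movie, subsample=4)     # percentiles estimated on movie[..., ::4, ::4]
exact = normalize(movie, subsample=None)
print(np.abs(fast - exact).max())        # small: the estimated percentiles barely differ
```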