Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 83 additions & 3 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,22 @@ def _common_step(self, batch, eps=torch.finfo(torch.float32).eps):
padding_mask = batch["padding_mask"]
padding_mask = padding_mask.bool()

A_pred = self.model(coords, feats, padding_mask=padding_mask)
pretrained_feats = batch.get("pretrained_feats", None)
if pretrained_feats is not None and pretrained_feats.numel() > 0:
pretrained_feats = pretrained_feats.to(coords.device)
else:
pretrained_feats = None

if pretrained_feats is not None:
A_pred = self.model(
coords,
feats,
pretrained_features=pretrained_feats,
padding_mask=padding_mask,
)
else:
A_pred = self.model(coords, feats, padding_mask=padding_mask)

# remove inf values that might happen due to float16 numerics
A_pred.clamp_(torch.finfo(torch.float16).min, torch.finfo(torch.float16).max)

Expand Down Expand Up @@ -632,7 +647,8 @@ def on_validation_end(self, trainer, pl_module):

def create_run_name(args):
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# name = f"{timestamp}_{args.name}_feats_{args.features}_pos_{args.attn_positional_bias}_causal_norm_{args.causal_norm}"
# name = f"{timestamp}_{args.name}_feats_{args.features}_pos_" + \
# f"{args.attn_positional_bias}_causal_norm_{args.causal_norm}"
if args.timestamp:
name = f"{timestamp}_{args.name}"
else:
Expand Down Expand Up @@ -817,6 +833,9 @@ def train(args):
sanity_dist=args.sanity_dist,
crop_size=args.crop_size,
compress=args.compress,
pretrained_feats_model=args.pretrained_feats_model,
pretrained_feats_mode=args.pretrained_feats_mode,
pretrained_feats_additional_props=args.pretrained_feats_additional_props,
)
sampler_kwargs = dict(
batch_size=args.batch_size,
Expand Down Expand Up @@ -920,7 +939,8 @@ def train(args):
# Compiling does not work!
# model_lightning = torch.compile(model_lightning)

# if logdir already exists and --resume option is set, load the last checkpoint (eg when continuing training after crash)
# if logdir already exists and --resume option is set,
# load the last checkpoint (eg when continuing training after crash)
if logdir is not None and logdir.exists() and args.resume:
logging.info("logdir exists, loading last state of model")
fpath = model_lightning.checkpoint_path(logdir)
Expand Down Expand Up @@ -1065,9 +1085,69 @@ def parse_train_args():
"patch",
"patch_regionprops",
"wrfeat",
"pretrained_feats",
"pretrained_feats_aug",
],
default="wrfeat",
)
parser.add_argument(
"--pretrained_feats_model",
type=str,
default=None,
help="Model name for pretrained feature extraction (e.g. facebook/sam2.1-hiera-base-plus)",
)
parser.add_argument(
"--pretrained_feats_mode",
type=str,
default="mean_patches_exact",
help="Pooling mode for pretrained features",
)
parser.add_argument(
"--pretrained_feats_additional_props",
type=none_or_str,
default=None,
help="Additional region properties to concatenate with pretrained features (e.g. regionprops_small)",
)
parser.add_argument(
"--pretrained_n_augs",
type=int,
default=15,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this may be a bit high, depending on the dataset size. I'd recommend starting with much lower values and leaning on the feature disambiguation to avoid overfitting.

help="Number of augmented dataset copies to create for pretrained features extraction",
)
parser.add_argument(
"--reduced_pretrained_feat_dim",
type=int,
default=None,
help="Reduce pretrained feature dimension via PCA to this size",
Copy link
Copy Markdown
Contributor

@C-Achard C-Achard Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it does not look like you explicitly re-implemented the PCA dimensionality reduction I used at some point (but that never made it into the final pipeline), I think this refers to the dimension of the pretrained features after a single FCL as in https://github.com/C-Achard/trackastra/blob/a238b2cadc8e3b954c4af4afeba6df8faf18be71/trackastra/model/model.py#L296.
So this should not mention PCA, but rather the dim of pretrained features you'd like to feed to the encoder (which gets concatenated to the additional region props)

)
parser.add_argument(
"--rotate_features",
type=str2bool,
default=True,
help="Apply feature disambiguation to pretrained features based on coordinates to mitigate overfitting and avoid proximity-induced ambiguity in pretrained features",
)
Comment thread
anwai98 marked this conversation as resolved.
parser.add_argument(
"--disable_all_coords",
type=str2bool,
default=False,
)
parser.add_argument(
"--disable_xy_coords",
type=str2bool,
default=False,
)
parser.add_argument(
"--pretrained_model_path",
type=none_or_str,
default=None,
help="Path to a local pretrained model folder (overrides --model for loading weights)",
)
parser.add_argument(
"--weight_decay",
type=float,
default=0.0,
help="AdamW weight decay",
)
parser.add_argument(
"--causal_norm",
type=str,
Expand Down
2 changes: 0 additions & 2 deletions trackastra/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# ruff: noqa: F401

# Core data utilities (no training dependencies required)
from .data import (
CTCData,
Expand Down
75 changes: 62 additions & 13 deletions trackastra/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,19 @@ def __init__(
self.detection_folders = detection_folders
self.ndim = ndim
self.features = features
self.pretrained_feats_model = kwargs.get("pretrained_feats_model", None)
self.pretrained_feats_mode = kwargs.get(
"pretrained_feats_mode", "mean_patches_exact"
)
self.pretrained_feats_additional_props = kwargs.get(
"pretrained_feats_additional_props", None
)

if features not in ("none", "wrfeat") and features not in _PROPERTIES[ndim]:
if (
features
not in ("none", "wrfeat", "pretrained_feats", "pretrained_feats_aug")
and features not in _PROPERTIES[ndim]
):
raise ValueError(
f"'{features}' not one of the supported {ndim}D features"
f" {tuple(_PROPERTIES[ndim].keys())}"
Expand Down Expand Up @@ -225,7 +236,7 @@ def __init__(

start = default_timer()

if self.features == "wrfeat":
if self.features in ("wrfeat", "pretrained_feats", "pretrained_feats_aug"):
self.windows = self._load_wrfeat()
else:
self.windows = self._load()
Expand Down Expand Up @@ -295,7 +306,7 @@ def from_arrays(cls, imgs: np.ndarray, masks: np.ndarray, train_args: dict):

start = default_timer()

if self.features == "wrfeat":
if self.features in ("wrfeat", "pretrained_feats", "pretrained_feats_aug"):
self.windows = self._load_wrfeat()
else:
self.windows = self._load()
Expand Down Expand Up @@ -343,7 +354,7 @@ def _setup_features_augs(
default_augmenter,
)

if self.features == "wrfeat":
if self.features in ("wrfeat", "pretrained_feats", "pretrained_feats_aug"):
return self._setup_features_augs_wrfeat(ndim, features, augment, crop_size)

cropper = (
Expand Down Expand Up @@ -774,7 +785,8 @@ def _build_windows(
f" {det_folder}:{t1}"
)

# build matrix from incomplete labels, but full lineage graph. If a label is missing, I should skip over it.
# build matrix from incomplete labels, but full lineage graph.
# If a label is missing, I should skip over it.
A = _ctc_assoc_matrix(
_labels,
_ts,
Expand Down Expand Up @@ -808,7 +820,7 @@ def _build_windows(

def __getitem__(self, n: int, return_dense=None):
# if not set, use default
if self.features == "wrfeat":
if self.features in ("wrfeat", "pretrained_feats", "pretrained_feats_aug"):
return self._getitem_wrfeat(n, return_dense)

if return_dense is None:
Expand Down Expand Up @@ -1042,6 +1054,7 @@ def _load_wrfeat(self):
self.gt_masks = self._check_dimensions(self.gt_masks)

# Load images
raw_imgs = None
if self.img_folder is None:
if self.gt_masks is not None:
self.imgs = np.zeros_like(self.gt_masks)
Expand All @@ -1050,10 +1063,12 @@ def _load_wrfeat(self):
else:
logger.info("Loading images")
imgs = self._load_tiffs(self.img_folder, dtype=np.float32)
raw_imgs = np.stack(list(imgs)) # keep raw for pretrained feature extractor
self.imgs = np.stack([
normalize(_x) for _x in tqdm(imgs, desc="Normalizing", leave=False)
normalize(_x) for _x in tqdm(raw_imgs, desc="Normalizing", leave=False)
])
self.imgs = self._check_dimensions(self.imgs)
raw_imgs = self._check_dimensions(raw_imgs)
if self.compress:
# prepare images to be compressed later (e.g. removing non masked parts for regionprops features)
self.imgs = np.stack([
Expand Down Expand Up @@ -1107,13 +1122,40 @@ def _load_wrfeat(self):
self.det_masks[_f] = det_masks

# build features
if self.features in ("pretrained_feats", "pretrained_feats_aug"):
from trackastra_pretrained_feats import (
FeatureExtractor,
WRPretrainedFeatures,
)

features = joblib.Parallel(n_jobs=8)(
joblib.delayed(wrfeat.WRFeatures.from_mask_img)(
mask=mask[None], img=img[None], t_start=t
device = "cuda" if torch.cuda.is_available() else "cpu"
feature_extractor = FeatureExtractor.from_model_name(
self.pretrained_feats_model,
self.imgs.shape[-2:],
save_path=self.root / "embeddings",
mode=self.pretrained_feats_mode,
device=device,
additional_features=self.pretrained_feats_additional_props,
)
imgs_for_extractor = raw_imgs if raw_imgs is not None else self.imgs
feature_extractor.precompute_image_embeddings(imgs_for_extractor)
features = [
WRPretrainedFeatures.from_mask_img(
img=img[None],
mask=mask[None],
feature_extractor=feature_extractor,
t_start=t,
additional_properties=feature_extractor.additional_features,
)
for t, (mask, img) in enumerate(zip(det_masks, self.imgs))
]
else:
features = joblib.Parallel(n_jobs=8)(
joblib.delayed(wrfeat.WRFeatures.from_mask_img)(
mask=mask[None], img=img[None], t_start=t
)
for t, (mask, img) in enumerate(zip(det_masks, self.imgs))
)
for t, (mask, img) in enumerate(zip(det_masks, self.imgs))
)

properties_by_time = dict()
for _t, _feats in enumerate(features):
Expand Down Expand Up @@ -1162,7 +1204,8 @@ def _build_windows_wrfeat(
A = np.zeros((0, 0), dtype=bool)
coords = np.zeros((0, feat.ndim), dtype=int)
else:
# build matrix from incomplete labels, but full lineage graph. If a label is missing, I should skip over it.
# build matrix from incomplete labels, but full lineage graph.
# If a label is missing, I should skip over it.
A = _ctc_assoc_matrix(
labels,
timepoints,
Expand Down Expand Up @@ -1229,6 +1272,9 @@ def _getitem_wrfeat(self, n: int, return_dense=None):
coords0 = torch.from_numpy(coords0).float()
assoc_matrix = torch.from_numpy(assoc_matrix.astype(np.float32))
features = torch.from_numpy(feat.features_stacked).float()
pretrained_feats = feat.pretrained_feats
if pretrained_feats is not None:
pretrained_feats = torch.from_numpy(pretrained_feats).float()
labels = torch.from_numpy(feat.labels).long()
timepoints = torch.from_numpy(feat.timepoints).long()

Expand All @@ -1240,6 +1286,8 @@ def _getitem_wrfeat(self, n: int, return_dense=None):
coords0 = coords0[:n_elems]
features = features[:n_elems]
assoc_matrix = assoc_matrix[:n_elems, :n_elems]
if pretrained_feats is not None:
pretrained_feats = pretrained_feats[:n_elems]
logger.debug(
f"Clipped window of size {timepoints[n_elems - 1] - timepoints.min()}"
)
Expand All @@ -1251,6 +1299,7 @@ def _getitem_wrfeat(self, n: int, return_dense=None):
coords = coords0.clone()
res = dict(
features=features,
pretrained_feats=pretrained_feats,
coords0=coords0,
coords=coords,
assoc_matrix=assoc_matrix,
Expand Down
1 change: 1 addition & 0 deletions trackastra/data/wrfeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ def _transform_affine(k: str, v: np.ndarray, M: np.ndarray):
"intensity_max",
"intensity_min",
"border_dist",
"pretrained_feats",
):
pass
else:
Expand Down
2 changes: 0 additions & 2 deletions trackastra/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# ruff: noqa: F401

import os

from .model import TrackingTransformer
Expand Down
2 changes: 0 additions & 2 deletions trackastra/tracking/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# ruff: noqa: F401

from .track_graph import TrackGraph
from .tracking import (
build_graph,
Expand Down
2 changes: 0 additions & 2 deletions trackastra/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# ruff: noqa: F401

from .utils import (
blockwise_causal_norm,
blockwise_sum,
Expand Down
8 changes: 3 additions & 5 deletions trackastra/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def render_label(
im_img = img[..., :4]
if img.shape[-1] < 4:
im_img = np.concatenate(
[img, np.ones(img.shape[:2] + (4 - img.shape[-1],))], axis=-1
[img, np.ones((*img.shape[:2], 4 - img.shape[-1]))], axis=-1
)
else:
raise ValueError("img should be 2 or 3 dimensional")
Expand Down Expand Up @@ -378,15 +378,13 @@ def preallocate_memory(dataset, model_lightning, batch_size, max_tokens, device)
batch = dict(
features=batched(
torch.zeros(
(max_len,) + x["features"].shape[1:], dtype=x["features"].dtype
(max_len, *x["features"].shape[1:]), dtype=x["features"].dtype
),
batch_size,
device,
),
coords=batched(
torch.zeros(
(max_len,) + x["coords"].shape[1:], dtype=x["coords"].dtype
),
torch.zeros((max_len, *x["coords"].shape[1:]), dtype=x["coords"].dtype),
batch_size,
device,
),
Expand Down
Loading