diff --git a/pyproject.toml b/pyproject.toml index 9db81ac..fd5b2c0 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,19 +28,21 @@ dependencies = [ "monai", "click", "copick", + "kornia", "nibabel", "mrcfile", "starfile", + "lightning", "matplotlib", - "kornia", - "opencv-python", - "multiprocess", - "torchmetrics", - "scikit-learn", "ipywidgets", "umap-learn", "torch-ema", + "tensorboard", + "multiprocess", + "torchmetrics", + "scikit-learn", "copick-utils", + "opencv-python", ] [project.optional-dependencies] diff --git a/saber/classifier/datasets/augment.py b/saber/classifier/datasets/augment.py index 077c731..da13603 100755 --- a/saber/classifier/datasets/augment.py +++ b/saber/classifier/datasets/augment.py @@ -2,10 +2,11 @@ Compose, EnsureChannelFirstd, NormalizeIntensityd, Orientationd, RandRotate90d, RandFlipd, RandScaleIntensityd, RandShiftIntensityd, RandAdjustContrastd, RandGaussianNoised, RandAffined, RandomOrder, - RandGaussianSmoothd, + RandGaussianSmoothd, SqueezeDimd, ToNumpyd, EnsureTyped, ResizeD ) from saber.classifier.datasets.RandMaskCrop import AdaptiveCropd from torch.utils.data import random_split +import torch def get_preprocessing_transforms(random_translations=False): transforms = Compose([ @@ -20,13 +21,6 @@ def get_preprocessing_transforms(random_translations=False): def get_training_transforms(): train_transforms = Compose([ - # RandAffined( - # keys=["image", "mask"], - # prob=0.75, - # translate_range=(30, 30), - # padding_mode="border", - # mode=("bilinear", "nearest") - # ), RandomOrder([ RandRotate90d(keys=["image", "mask"], prob=0.5, spatial_axes=[0, 1]), RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=0), @@ -45,4 +39,28 @@ def get_validation_transforms(): def split_dataset(dataset, val_split=0.2): train_size = int(len(dataset) * (1 - val_split)) val_size = len(dataset) - train_size - return random_split(dataset, [train_size, val_size]) \ No newline at end of file + return random_split(dataset, [train_size, 
val_size]) + +def get_finetune_transforms(target_size=(1024,1024)): + transforms = Compose([ + EnsureChannelFirstd(keys=["image", "mask"], channel_dim="no_channel"), + EnsureTyped(keys=["image", "mask"], dtype=[torch.float32, torch.int64]), + ResizeD( + keys=["image", "mask"], + spatial_size=target_size, + mode=("bilinear", "nearest") + ), + RandomOrder([ + RandRotate90d(keys=["image", "mask"], prob=0.5, spatial_axes=[0, 1]), + RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=0), + RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=1), + RandScaleIntensityd(keys="image", prob=0.5, factors=(0.85, 1.15)), + RandShiftIntensityd(keys="image", prob=0.5, offsets=(-0.15, 0.15)), + RandAdjustContrastd(keys="image", prob=0.5, gamma=(0.85, 1.15)), + RandGaussianNoised(keys="image", prob=0.5, mean=0.0, std=1.5), + RandGaussianSmoothd(keys="image", prob=0.5, sigma_x=(0.25, 1.5), sigma_y=(0.25, 1.5)), + ]), + SqueezeDimd(keys=["image", "mask"], dim=0), + ToNumpyd(keys=["image", "mask"]), + ]) + return transforms \ No newline at end of file diff --git a/saber/classifier/preprocess/split_merge_data.py b/saber/classifier/preprocess/split_merge_data.py index 5090833..d00cd8d 100644 --- a/saber/classifier/preprocess/split_merge_data.py +++ b/saber/classifier/preprocess/split_merge_data.py @@ -1,9 +1,37 @@ from sklearn.model_selection import train_test_split from typing import List, Tuple, Dict, Optional +from zarr.convenience import copy as zarr_copy from pathlib import Path import click, zarr, os import numpy as np +def copy_like(src_arr, dst_group, path: str): + # Ensure parent groups for nested paths like "labels/0" + parent = dst_group + parts = path.split('/') + for p in parts[:-1]: + parent = parent.require_group(p) + + leaf = parts[-1] + # Create (or reuse) the dst array with identical metadata + dst_arr = parent.require_dataset( + leaf, + shape=src_arr.shape, + dtype=src_arr.dtype, + chunks=src_arr.chunks, + compressor=src_arr.compressor, + 
filters=src_arr.filters, + order=src_arr.order, + fill_value=src_arr.fill_value, + # Optional (if your sources use it): + **({"dimension_separator": getattr(src_arr, "dimension_separator", None)} + if hasattr(src_arr, "dimension_separator") else {}) + ) + dst_arr[:] = src_arr[:] # fast data copy + # Preserve array attrs + if src_arr.attrs: + dst_arr.attrs.update(src_arr.attrs) + def split( input: str, ratio: float, @@ -57,19 +85,25 @@ def split( items = ['0', 'labels/0', 'labels/rejected'] print('Copying data to train zarr file...') for key in train_keys: - train_zarr.create_group(key) # Explicitly create the group first - copy_attributes(zfile[key], train_zarr[key]) + dst_grp = train_zarr.require_group(key) + copy_attributes(zfile[key], dst_grp) for item in items: - train_zarr[key][item] = zfile[key][item][:] # [:] ensures a full copy - copy_attributes(zfile[key]['labels'], train_zarr[key]['labels']) + try: + copy_like(zfile[key][item], dst_grp, item) + copy_attributes(zfile[key][item], dst_grp[item]) + except Exception as e: + pass print('Copying data to validation zarr file...') for key in val_keys: - val_zarr.create_group(key) # Explicitly create the group first - copy_attributes(zfile[key], val_zarr[key]) + dst_grp = val_zarr.require_group(key) + copy_attributes(zfile[key], dst_grp) for item in items: - val_zarr[key][item] = zfile[key][item][:] # [:] ensures a full copy - copy_attributes(zfile[key]['labels'], val_zarr[key]['labels']) + try: + copy_like(zfile[key][item], dst_grp, item) + copy_attributes(zfile[key][item], dst_grp[item]) + except Exception as e: + pass # Print summary print(f"\nSplit Summary:") @@ -124,18 +158,16 @@ def merge(inputs: List[str], output: str): write_key = session_label + '_' + key # Create the group and copy its attributes - new_group = mergedZarr.create_group(write_key) # Explicitly create the group first - copy_attributes(zfile[key], new_group) + dst_grp = mergedZarr.require_group(write_key) + copy_attributes(zfile[key], dst_grp) 
# Copy the data arrays for item in items: try: - # [:] ensures a full copy - mergedZarr[write_key][item] = zfile[key][item][:] + copy_like(zfile[key][item], dst_grp, item) + copy_attributes(zfile[key][item], dst_grp[item]) except Exception as e: pass - # Copy attributes for labels subgroup - copy_attributes(zfile[key]['labels'], new_group['labels']) # Copy all attributes from the last input zarr file for attr_name, attr_value in zfile.attrs.items(): @@ -194,8 +226,5 @@ def copy_attributes(source, destination): """ if hasattr(source, 'attrs') and source.attrs: destination.attrs.update(source.attrs) - -if __name__ == '__main__': - cli() diff --git a/saber/finetune/README.md b/saber/finetune/README.md new file mode 100644 index 0000000..e69de29 diff --git a/saber/finetune/__init__.py b/saber/finetune/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/saber/finetune/abiou.py b/saber/finetune/abiou.py new file mode 100644 index 0000000..9b73d50 --- /dev/null +++ b/saber/finetune/abiou.py @@ -0,0 +1,351 @@ +from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator +from typing import List, Dict, Any, Optional, Callable, Union +from contextlib import nullcontext +import numpy as np +import torch + +# --------------------- IoU / ABIoU --------------------- + +def _mask_iou(a, b, eps=1e-6): + """ + a: [Na,H,W] {0,1}; b: [Nb,H,W] {0,1} -> IoU [Na,Nb] + """ + if a.numel() == 0 or b.numel() == 0: + dev = a.device if a.numel() > 0 else (b.device if b.numel() > 0 else torch.device("cpu")) + Na = a.shape[0] if a.numel() > 0 else 0 + Nb = b.shape[0] if b.numel() > 0 else 0 + return torch.zeros((Na, Nb), device=dev, dtype=torch.float32) + + a = a.float() + b = b.float() + inter = torch.einsum("nhw,mhw->nm", a, b) # [Na,Nb] + + ua = a.sum(dim=(1,2))[:, None] # [Na,1] + ub = b.sum(dim=(1,2))[None, :] # [1,Nb] + + union = ua + ub - inter + eps # [Na,Nb] + return inter / union + +def _abiou(proposals, gts): + """ + Average Best IoU (coverage metric). 
+ proposals: [Np,H,W] {0,1}, gts: [Ng,H,W] {0,1} + """ + if gts.numel() == 0 and proposals.numel() == 0: + dev = proposals.device if proposals.numel() > 0 else gts.device + return torch.tensor(1.0, device=dev, dtype=torch.float32) + if gts.numel() == 0 or proposals.numel() == 0: + dev = proposals.device if proposals.numel() > 0 else gts.device + return torch.tensor(0.0, device=dev, dtype=torch.float32) + iou = _mask_iou(gts, proposals) # [Ng,Np] + best = iou.max(dim=1).values # [Ng] + return best.mean() + +# --------------------- Utilities --------------------- + +def _to_bool_tensor(x: Union[np.ndarray, torch.Tensor], device: torch.device) -> torch.Tensor: + """ + Accepts HxW or [N,H,W] arrays/tensors, binarizes (>0) and returns bool tensor on device with shape [N,H,W]. + """ + if isinstance(x, np.ndarray): + arr = x + if arr.ndim == 2: + arr = arr[None, ...] + t = torch.from_numpy(arr) + elif isinstance(x, torch.Tensor): + t = x + if t.ndim == 2: + t = t.unsqueeze(0) + else: + raise TypeError(f"Unsupported mask type: {type(x)}") + # binarize and cast to bool + t = (t != 0) + return t.to(device=device, dtype=torch.bool) + + +def _downsample_bool_masks(m: torch.Tensor, factor: int) -> torch.Tensor: + """ + Downsample boolean masks by a small integer factor via max-pooling (keeps foreground coverage). 
+ m: [N,H,W] bool + """ + if factor <= 1 or m.numel() == 0: + return m + # reshape for pooling + N, H, W = m.shape + H2 = H // factor + W2 = W // factor + if H2 == 0 or W2 == 0: + return m + # crop to divisible + m = m[:, :H2 * factor, :W2 * factor] + # convert to float for pooling-like reduction via unfold + mf = m.float() + mf = mf.unfold(1, factor, factor).unfold(2, factor, factor) # [N, H2, W2, f, f] + # max over the small window -> any(True) + mf = mf.contiguous().view(N, H2, W2, -1).max(dim=-1).values + return (mf > 0).to(dtype=torch.bool) + + +# --------------------- IoU (vectorized) --------------------- + +def _pairwise_iou_bool(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + """ + a: [Na,H,W] bool, b: [Nb,H,W] bool -> IoU [Na,Nb] float32 + Vectorized via flatten + matmul. Stays on device. + """ + Na = a.shape[0] + Nb = b.shape[0] + if Na == 0 or Nb == 0: + return a.new_zeros((Na, Nb), dtype=torch.float32) + + # Flatten (reshape tolerates non-contiguous inputs) + a_f = a.reshape(Na, -1).float() + b_f = b.reshape(Nb, -1).float() + + # Areas and intersections + areas_a = a_f.sum(dim=1) # [Na] + areas_b = b_f.sum(dim=1) # [Nb] + inter = a_f @ b_f.t() # [Na,Nb] + + # Unions + union = areas_a[:, None] + areas_b[None, :] - inter + eps + return (inter / union).to(torch.float32) + + +# --------------------- Metrics --------------------- + +def abiou_original(proposals: torch.Tensor, gts: torch.Tensor) -> torch.Tensor: + """ + ABIoU as mean over GTs of max IoU to any proposal (allows proposal reuse). 
+ proposals, gts: [N,H,W] bool (on same device) + """ + dev = gts.device if gts.numel() > 0 else (proposals.device if proposals.numel() > 0 else torch.device("cpu")) + if gts.numel() == 0 and proposals.numel() == 0: + return torch.tensor(1.0, device=dev, dtype=torch.float32) + if gts.numel() == 0 or proposals.numel() == 0: + return torch.tensor(0.0, device=dev, dtype=torch.float32) + + iou = _pairwise_iou_bool(gts, proposals) # [Ng,Np] + best = iou.max(dim=1).values # [Ng] + return best.mean() + + +def abiou_one_to_one_greedy(proposals: torch.Tensor, gts: torch.Tensor) -> torch.Tensor: + """ + ABIoU with one-to-one greedy matching (no proposal reuse). + proposals, gts: [N,H,W] bool + """ + dev = gts.device if gts.numel() > 0 else (proposals.device if proposals.numel() > 0 else torch.device("cpu")) + if gts.numel() == 0 and proposals.numel() == 0: + return torch.tensor(1.0, device=dev, dtype=torch.float32) + if gts.numel() == 0 or proposals.numel() == 0: + return torch.tensor(0.0, device=dev, dtype=torch.float32) + + iou = _pairwise_iou_bool(gts, proposals) # [Ng,Np] + Ng, Np = iou.shape + used_g = torch.zeros(Ng, dtype=torch.bool, device=dev) + used_p = torch.zeros(Np, dtype=torch.bool, device=dev) + + matched_sum = torch.tensor(0.0, device=dev) + # Greedy loop at most min(Ng, Np) steps + for _ in range(min(Ng, Np)): + # mask used rows/cols by setting to -1 + iou_masked = iou.clone() + if used_g.any(): + iou_masked[used_g, :] = -1 + if used_p.any(): + iou_masked[:, used_p] = -1 + val, idx = torch.max(iou_masked.view(-1), dim=0) + if val <= 0: + break + g_idx = idx // Np + p_idx = idx % Np + matched_sum = matched_sum + val + used_g[g_idx] = True + used_p[p_idx] = True + + # Average over ALL GTs (unmatched GTs count 0) + return matched_sum / max(Ng, 1) + + +def union_iou(proposals: torch.Tensor, gts: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + """ + Pixel-set IoU between union of proposals and union of GTs. 
+ proposals, gts: [N,H,W] bool + """ + dev = gts.device if gts.numel() > 0 else (proposals.device if proposals.numel() > 0 else torch.device("cpu")) + if gts.numel() == 0 and proposals.numel() == 0: + return torch.tensor(1.0, device=dev, dtype=torch.float32) + + if proposals.numel() > 0: + P = proposals.any(dim=0) + else: + # create zero map based on gts ref + H, W = gts.shape[-2], gts.shape[-1] + P = torch.zeros((H, W), dtype=torch.bool, device=dev) + + if gts.numel() > 0: + G = gts.any(dim=0) + else: + H, W = proposals.shape[-2], proposals.shape[-1] + G = torch.zeros((H, W), dtype=torch.bool, device=dev) + + inter = (P & G).sum().float() + uni = (P | G).sum().float() + eps + return (inter / uni).to(torch.float32) + + +# --------------------- Main evaluator --------------------- + +@torch.no_grad() +def automask_metrics( + sam2_model_or_predictor: Any, + images: List[Union[np.ndarray, torch.Tensor]], # HxW or HxWx3 (uint8 preferred) + gt_masks_per_image: List[List[Union[np.ndarray, torch.Tensor]]], # per-image list of HxW masks + *, + amg_kwargs: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = 200, + device: Optional[torch.device] = None, + autocast_ctx: Optional[Callable[[], Any]] = None, + downsample_factor: int = 1, + return_per_image: bool = False, +) -> Dict[str, Any]: + """ + Run SAM2AutomaticMaskGenerator once per image and compute: + - ABIoU_one_to_one (greedy, no reuse) + - UnionIoU + - ABIoU_original (optional reference) + + Speed features: + - Single IoU matrix fuels both ABIoUs. + - Everything stays on GPU; masks are boolean. + - Optional downsample_factor (e.g., 2 or 4) for huge speedups. + + Returns: + { + 'ABIoU_one_to_one': float, + 'UnionIoU': float, + 'ABIoU_original': float, + 'num_images': int, + 'per_image': [ ... 
] # if return_per_image + } + """ + + # AMG defaults (safe, tweak as needed) + _amg = dict( + points_per_side=32, + points_per_batch=128, + pred_iou_thresh=0.7, + stability_score_thresh=0.92, + stability_score_offset=0.7, + crop_n_layers=1, + crop_n_points_downscale_factor=2, + box_nms_thresh=0.7, + use_m2m=False, + multimask_output=True, + ) + if amg_kwargs: + _amg.update(amg_kwargs) + + model_for_amg = getattr(sam2_model_or_predictor, "model", sam2_model_or_predictor) + mask_generator = SAM2AutomaticMaskGenerator(model=model_for_amg, **_amg) + + # Device + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Autocast (AMG forward only) + ac = autocast_ctx if autocast_ctx is not None else (lambda: nullcontext()) + + # Accumulators + one2one_vals, union_vals, abiou_orig_vals = [], [], [] + per_image_out = [] + + for img, gt_list in zip(images, gt_masks_per_image): + # ---- Ensure numpy uint8 image for AMG ---- + if isinstance(img, torch.Tensor): + img_np = img.detach().cpu().numpy() + else: + img_np = img + H, W = img_np.shape[:2] + + # ---- AMG forward ---- + with ac(): + proposals = mask_generator.generate(img_np) # list of dict + + # ---- Convert proposals -> [Np,H,W] bool on device ---- + if len(proposals) == 0: + prop_masks = torch.zeros((0, H, W), dtype=torch.bool, device=device) + else: + # sort by predicted_iou (or score), keep top_k + def _score(d): + return float(d.get("predicted_iou", d.get("score", 0.0))) + proposals.sort(key=_score, reverse=True) + if top_k is not None and top_k > 0: + proposals = proposals[:top_k] + masks_np = [(p["segmentation"] > 0).astype(np.uint8) for p in proposals] + prop_masks = torch.from_numpy(np.stack(masks_np, axis=0)).to(device=device, dtype=torch.bool) + prop_masks = prop_masks.contiguous() + + # ---- Convert GTs -> [Ng,H,W] bool on device ---- + if len(gt_list) == 0: + gt_masks = torch.zeros((0, H, W), dtype=torch.bool, device=device) + else: + gt_bool = [] + for g in gt_list: 
+ if isinstance(g, torch.Tensor): + g_np = g.detach().cpu().numpy() + else: + g_np = g + gt_bool.append((g_np > 0).astype(np.uint8)) + gt_masks = torch.from_numpy(np.stack(gt_bool, axis=0)).to(device=device, dtype=torch.bool) + gt_masks = gt_masks.contiguous() + + # ---- Optional downsample (max-pool style) ---- + if downsample_factor > 1: + prop_masks_ds = _downsample_bool_masks(prop_masks, downsample_factor) + gt_masks_ds = _downsample_bool_masks(gt_masks, downsample_factor) + else: + prop_masks_ds = prop_masks + gt_masks_ds = gt_masks + + # ---- Metrics (single IoU matrix shared under the hood) ---- + m_one2one = abiou_one_to_one_greedy(prop_masks_ds, gt_masks_ds) + m_union = union_iou(prop_masks_ds, gt_masks_ds) + m_orig = abiou_original(prop_masks_ds, gt_masks_ds) + + one2one_vals.append(m_one2one) + union_vals.append(m_union) + abiou_orig_vals.append(m_orig) + + if return_per_image: + per_image_out.append({ + "ABIoU": float(m_one2one.detach().cpu()), + "ABIoU_original": float(m_orig.detach().cpu()), + "num_props": int(prop_masks.shape[0]), + "num_gt": int(gt_masks.shape[0]), + "H": int(H), + "W": int(W), + }) + + # ---- Averages ---- + if len(one2one_vals) == 0: + return { + "ABIoU_one_to_one": 0.0, + "UnionIoU": 0.0, + "ABIoU_original": 0.0, + "num_images": 0, + "per_image": [], + } + + ABIoU_one_to_one = torch.stack(one2one_vals).mean().item() + ABIoU_original_avg = torch.stack(abiou_orig_vals).mean().item() + + out = { + "ABIoU": ABIoU_one_to_one, + "ABIoU_original": ABIoU_original_avg, + "num_images": len(one2one_vals), + } + if return_per_image: + out["per_image"] = per_image_out + return out diff --git a/saber/finetune/cli.py b/saber/finetune/cli.py new file mode 100644 index 0000000..8c9aa8b --- /dev/null +++ b/saber/finetune/cli.py @@ -0,0 +1,12 @@ +from saber.finetune.train import finetune +from saber.finetune.prep import main as prep +import click + +@click.group(name="finetune") +def finetune_routines(): + """Routines for finetuning SAM2 on New 
Modalities.""" + pass + +# Add subcommands to the group +finetune_routines.add_command(finetune) +finetune_routines.add_command(prep) diff --git a/saber/finetune/dataset.py b/saber/finetune/dataset.py new file mode 100644 index 0000000..b6207b7 --- /dev/null +++ b/saber/finetune/dataset.py @@ -0,0 +1,471 @@ +from scipy.ndimage import binary_erosion, binary_dilation +from monai.transforms import Compose, EnsureChannelFirstd +from typing import Optional, Tuple, Union +import saber.finetune.helper as helper +from saber.utils import preprocessing +from torch.utils.data import Dataset +import zarr, torch, random +from tqdm import tqdm +import numpy as np + +class AutoMaskDataset(Dataset): + def __init__(self, + tomogram_zarr_path: str = None, + fib_zarr_path: str = None, + transform = None, + num_slabs: int = 50, + num_slices: int = 50, + slab_thickness: int = 5, + seed: int = 42, + shuffle: bool = True): + """ + Args: + tomogram_zarr_path: Path to the tomogram zarr store + fib_zarr_path: Path to the fib zarr store + transform: Transform to apply to the data + num_slabs: Number of slabs per volume per epoch + num_slices: Number of slices per fib per epoch + slab_thickness: Thickness of the slab + """ + + # Slabs per Epoch + self.slab_thickness = slab_thickness + self.num_slabs = num_slabs + self.num_slices = num_slices + + # Grid and Positive Points for AutoMaskGenerator + self.points_per_side = 32 + self.min_pixels = 1e3 + self.k_min = 50 + self.k_max = 100 + self.transform = transform + self.keep_fraction = 0.5 + self.shuffle = shuffle + self.max_pos_pts = 8 + self.K_choices = [2, 2, 4, 4, 8] + + # Check if both data types are available + if tomogram_zarr_path is None and fib_zarr_path is None: + raise ValueError("At least one of tomogram_zarr_path or fib_zarr_path must be provided") + + # Flags to track which data types are available + self.has_tomogram = tomogram_zarr_path is not None + self.has_fib = fib_zarr_path is not None + + # Initialize tomogram data if 
provided + if self.has_tomogram: + self.tomogram_store = zarr.open(tomogram_zarr_path, mode='r') + self.tomogram_keys = [k for k in self.tomogram_store.keys() if not k.startswith('.')] + self.tomo_shapes = {} + for i, key in tqdm(enumerate(self.tomogram_keys), + total=len(self.tomogram_keys), desc="Estimating zrange for tomograms"): + try: + self.tomo_shapes[i] = self._estimate_zrange(key) + except Exception as e: + print(f"Error estimating zrange for tomogram {key}: {e}") + # remove key from tomogram_keys + self.tomogram_keys.remove(key) + self.n_tomogram_volumes = len(self.tomogram_keys) + else: + self.n_tomogram_volumes = 0 + self.tomo_shapes = {} + self.tomogram_keys = [] + + # Initialize fib data if provided + if self.has_fib: + self.fib_store = zarr.open(fib_zarr_path, mode='r') + self.fib_keys = [k for k in self.fib_store.keys() if not k.startswith('.')] + self.n_fib_volumes = len(self.fib_keys) + self.fib_shapes = {} + for i, key in enumerate(self.fib_keys): + try: + self.fib_shapes[i] = self._estimate_zrange(key, source='fib') + except Exception as e: + print(f"Error estimating zrange for fib {key}: {e}") + # remove key from fib_keys + self.fib_keys.remove(key) + self.n_fib_volumes = len(self.fib_keys) + else: + self.n_fib_volumes = 0 + self.fib_shapes = {} + self.fib_keys = [] + + # Random seed + self.seed = seed + self._rng = np.random.RandomState(seed) + + # Verbose Flag + self.verbose = False + + # Samples + self.tomogram_samples = [] + self.fib_samples = [] + self._prev_tomogram_samples = [] + self._prev_fib_samples = [] + + # First sampling epoch + self.resample_epoch() + + def _estimate_zrange(self, key, band=(0.3, 0.7), source = 'tomo', threshold=0): + """ + Returns (z_min, z_max) inclusive bounds for valid slab centers + where there is some foreground in the labels. 
+ - threshold: min # of fg pixels to count a slice as non-empty (0 = any) + - band: fraction of Z to consider (lo, hi) + """ + + if source == 'tomo': + nz = self.tomogram_store[key]['0'].shape[0] + elif source == 'fib': + nz = self.fib_store[key]['0'].shape[0] + min_offset, max_offset = int(nz * band[0]), int(nz * band[1]) + if source == 'tomo': + mask = self.tomogram_store[key]['labels/0'][min_offset:max_offset,] + elif source == 'fib': + mask = self.fib_store[key]['labels/0'][min_offset:max_offset,] + vals = mask.sum(axis=(1,2)) + vals = np.nonzero(vals)[0] + max_val = vals.max() - self.slab_thickness + min_offset + min_val = vals.min() + self.slab_thickness + min_offset + return int(min_val), int(max_val) + + def _compute_indices( + self, + D: int, + N: int, + ) -> np.ndarray: + """ + Precompute N slice indices for a volume of depth D. + + Args: + D: total #slices (depth). + N: total indices to return. + center_bias: fraction sampled from a Gaussian; the rest uniform. + sigma: std-dev for the Gaussian as a fraction of the (bounded) depth. + z_bounds: optional (z_min, z_max). If ints, treated as absolute slice + indices (inclusive). If floats in [0,1], treated as fractions + of D (e.g., (0.1, 0.9)). Indices are clamped into [0, D-1]. + + Returns: + np.ndarray of shape (N,) with integer slice indices in ascending order. 
+ """ + assert D > 0 and N > 0 + + center_bias = 0.9 + sigma = 0.2 + z0, z1 = 0, D - 1 + + # --- Resolve bounds --- + if z1 < z0: + # empty window → fall back to full range + z0, z1 = 0, D - 1 + + # Window length (inclusive) + W = (z1 - z0 + 1) + if W <= 0: + z0, z1 = 0, D - 1 + W = D + + # --- How many from each sampler --- + n_g = int(round(center_bias * N)) + n_u = N - n_g + + # --- Gaussian over the bounded window --- + mu = z0 + 0.5 * (W - 1) + s = sigma * (W - 1) + if s <= 0: + # degenerate window or sigma=0 → just center + g = np.full(n_g, int(round(mu)), dtype=int) + else: + g = self._rng.normal(mu, s, size=n_g) + g = np.clip(np.round(g).astype(int), z0, z1) + + # --- Uniform over the bounded window --- + u = self._rng.randint(z0, z1 + 1, size=n_u, dtype=int) + + # --- Merge, unique, and pad if needed --- + idx = np.unique(np.concatenate([g, u])) + # If de-dup removed too many, top up with uniform draws + while idx.size < N: + extra = self._rng.randint(z0, z1 + 1, size=N - idx.size, dtype=int) + idx = np.unique(np.concatenate([idx, extra])) + + # Return exactly N indices (sorted for reproducibility) + return np.sort(idx[:N]) + + def resample_epoch(self): + """ Generate new random samples for this epoch """ + + # Sample random slabs from each tomogram + if self.has_tomogram: + print(f"Re-Sampling {self.num_slabs} slabs from {self.n_tomogram_volumes} tomograms") + new_tomo_samples = [] + for vol_idx in tqdm(range(self.n_tomogram_volumes), desc="Sampling slabs from tomograms"): + + # Sample Indices in [0, D-1], center-biased within the *window* + z_min, z_max = self.tomo_shapes[vol_idx] + D = (z_max - z_min + 1) + z_rel = self._compute_indices(D, self.num_slabs) + + # Shift to absolute z and add to samples + z_positions = z_rel + z_min + for z_pos in z_positions: + new_tomo_samples.append((vol_idx, z_pos)) + self.tomogram_samples = new_tomo_samples + + # Shuffle samples + if self.shuffle: self._rng.shuffle(self.tomogram_samples) + + # Sample random slices from 
each FIB volume + if self.has_fib: + print(f"Re-Sampling {self.num_slices} slices from {self.n_fib_volumes} FIB volumes") + new_fib_samples = [] + for fib_idx in tqdm(range(self.n_fib_volumes), desc="Sampling slices from FIB volumes"): + # Sample random z positions from this FIB volume + z_min, z_max = self.fib_shapes[fib_idx] + D = (z_max - z_min + 1) + z_rel = self._compute_indices( + D, self.num_slices ) + # Shift to absolute z and add to samples + z_positions = z_rel + z_min + for z_pos in z_positions: + new_fib_samples.append((fib_idx, z_pos)) + self.fib_samples = new_fib_samples + + if self.shuffle: self._rng.shuffle(self.fib_samples) # Shuffle samples + + # Set epoch length + self.epoch_length = len(self.tomogram_samples) + len(self.fib_samples) + + def __len__(self): + return self.epoch_length + + def __getitem__(self, idx): + + # Get item from tomogram or FIB + if idx < len(self.tomogram_samples) and self.has_tomogram: + return self._get_tomogram_item(idx) + else: + return self._get_fib_item(idx - len(self.tomogram_samples)) + + def _get_tomogram_item(self, idx): + + # Randomly select a tomogram volume + vol_idx, z_pos = self.tomogram_samples[idx] + key = self.tomogram_keys[vol_idx] + + # Load image and segmentation slab + z_start = z_pos - self.slab_thickness // 2 + z_end = z_pos + self.slab_thickness // 2 + 1 + image_slab = self.tomogram_store[key]['0'][z_start:z_end] + seg_slab = self.tomogram_store[key]['labels/0'][z_start:z_end] + + # Project slab and segmentation + image_2d = preprocessing.project_tomogram(image_slab) + seg_2d = preprocessing.project_segmentation(seg_slab) # HxW + + return self._package_image_item(image_2d, seg_2d) + + def _get_fib_item(self, idx): + + # Randomly select a FIB volume + fib_idx, z_pos = self.fib_samples[idx] + key = self.fib_keys[fib_idx] + + # Load FIB image and segmentation + image = self.fib_store[key]['0'][z_pos,].astype(np.float32) # HxW + seg_2d = self.fib_store[key]['labels/0'][z_pos,] # HxW + return 
self._package_image_item(image, seg_2d) + + def _gen_grid_points(self, h: int, w: int) -> np.ndarray: + """ + Generate grid points for a given image size + """ + xs = np.linspace(0.5, w - 0.5, self.points_per_side, dtype=np.float32) + ys = np.linspace(0.5, h - 0.5, self.points_per_side, dtype=np.float32) + xx, yy = np.meshgrid(xs, ys) + return np.stack([xx.ravel(), yy.ravel()], axis=1) # (G,2) as (x,y) + + def _sample_negative_ring(self, comp, other_inst=None, ring=3, max_neg=16, shape=None): + h, w = shape + comp_b = comp.astype(np.bool_) + outer = binary_dilation(comp_b, iterations=ring) & (~comp_b) + if other_inst is not None: + # avoid putting negatives inside any other instance + outer = outer & (~other_inst.astype(np.bool_)) + ys, xs = np.where(outer) + if len(xs) == 0: + return np.zeros((0, 2), np.float32) + k = min(max_neg, len(xs)) + idx = self._rng.choice(len(xs), size=k, replace=False) + return np.stack([xs[idx].astype(np.float32), ys[idx].astype(np.float32)], axis=1) + + def _sample_points_in_mask( + self, + comp: np.ndarray, + grid_points: np.ndarray, + shape: tuple[int, int], + jitter_px: float = 1.0, + k_cap: int = 100, + boundary_frac: float = 0.35, + ) -> np.ndarray: + """ + Pick informative clicks from a dense grid: + - favor boundary points + - cap count ~sqrt(area) up to k_cap + Returns float32 array [K,2] in (x,y). 
+ """ + h, w = shape + if comp.sum() == 0: + return np.zeros((0, 2), np.float32) + + # ----- cast to boolean for morphology ----- + comp_b = comp.astype(np.bool_) # important: avoid TypeError with '^' + + # grid → nearest pixel indices for inside/boundary tests + gx = np.clip(np.rint(grid_points[:, 0]).astype(int), 0, w - 1) + gy = np.clip(np.rint(grid_points[:, 1]).astype(int), 0, h - 1) + + inside = comp_b[gy, gx] + cand = grid_points[inside] + if cand.shape[0] == 0: + return np.zeros((0, 2), np.float32) + + # ----- boundary mask (inner ring) ----- + eroded = binary_erosion(comp_b, iterations=2) + boundary_b = np.logical_and(comp_b, np.logical_not(eroded)) # same as comp ^ eroded on booleans + + on_b = boundary_b[gy, gx] & inside + cand_b = grid_points[on_b] + cand_i = grid_points[inside & (~on_b)] + + # target k ~ c * area but capped + area = float(comp_b.sum()) + k_target = int( + min( k_cap, max(12, area * 0.1) ) + ) + + kb = int(boundary_frac * k_target) + ki = k_target - kb + + rng = self._rng + take_b = min(kb, len(cand_b)) + take_i = min(ki, len(cand_i)) + + # if boundary is too small, backfill from interior + if take_b + take_i == 0: + return np.zeros((0, 2), np.float32) + + if take_b: + idx_b = rng.choice(len(cand_b), size=take_b, replace=False) + pts_b = cand_b[idx_b] + else: + pts_b = np.zeros((0, 2), np.float32) + + if take_i: + idx_i = rng.choice(len(cand_i), size=take_i, replace=False) + pts_i = cand_i[idx_i] + else: + pts_i = np.zeros((0, 2), np.float32) + + pts = np.concatenate([pts_b, pts_i], axis=0).astype(np.float32) + + # jitter a touch to avoid perfect grid regularity + if jitter_px > 0 and pts.shape[0] > 0: + jitter = rng.uniform(-jitter_px, jitter_px, size=pts.shape).astype(np.float32) + pts += jitter + pts[:, 0] = np.clip(pts[:, 0], 0, w - 1) + pts[:, 1] = np.clip(pts[:, 1], 0, h - 1) + + return pts + + def _package_image_item(self, + image_2d: np.ndarray, + segmentation: np.ndarray): + """ + Build per-component targets using grid-hit 
instances. + - Splits each hit instance into connected components. + - Drops components smaller than min_area_frac * (H*W). + - Emits only positive clicks + boxes (no negatives). + Returns: + { + "image": HxW, + "masks": list[H x W] float32 in {0,1}, + "points": list[#p x 2] float32 (xy), + "labels": list[#p] float32 (all ones), + "boxes": list[4] float32 (x0,y0,x1,y1) + } + """ + + # Apply transforms to image and segmentation + if self.transform: + sample = self.transform({'image': image_2d, 'mask': segmentation}) + image_2d, segmentation = sample['image'], sample['mask'] + + # Get image and segmentation shapes + h, w = segmentation.shape + + # which instances to train on for this image + grid_points = self._gen_grid_points(h, w) + inst_ids = helper.instances_from_grid(grid_points, segmentation) + masks_t, points_t, labels_t, boxes_t = [], [], [], [] + + for iid in inst_ids: + comps = helper.components_for_id(segmentation, iid, self.min_pixels) + for comp in comps: + # box from this component + box = helper.mask_to_box(comp) + if box is None: + continue + + K = self.K_choices[self._rng.randint(0, len(self.K_choices))] + pts_pos = self._sample_points_in_mask(comp, grid_points, (h, w)) + if pts_pos.shape[0] < 3: + continue + elif pts_pos.shape[0] > K: + sel = self._rng.choice(pts_pos.shape[0], size=K, replace=False) + pts_pos = pts_pos[sel] + + # negatives just outside this component, but not inside any *other* instance + use_negs = (self._rng.rand() < 0.5) + if use_negs: + other_inst = ((segmentation > 0) & (~comp.astype(bool))) + neg_ring = self._sample_negative_ring( + comp, other_inst=other_inst, ring=3, + max_neg=min(4, int(0.25 * len(pts_pos))), shape=(h, w) + ) + + pts = np.concatenate([pts_pos, neg_ring], 0) + lbl = np.concatenate([np.ones(len(pts_pos), np.float32), + np.zeros(len(neg_ring), np.float32)], 0) + else: + pts, lbl = pts_pos, np.ones((len(pts_pos),), np.float32) + + # shuffle to avoid positional bias + if pts.shape[0] > 1: + order = 
self._rng.permutation(pts.shape[0]) + pts, lbl = pts[order], lbl[order] + + masks_t.append(torch.from_numpy(comp.astype(np.float32))) + points_t.append(torch.from_numpy(pts.astype(np.float32))) + labels_t.append(torch.from_numpy(lbl.astype(np.float32))) # 1=pos, 0=neg + boxes_t.append(torch.from_numpy(box.astype(np.float32))) + + # fallback to a harmless dummy if nothing was hit by the grid (keeps loader stable) + if len(masks_t) == 0: + print("No masks found") + masks_t = [torch.from_numpy(np.zeros_like(segmentation, dtype=np.float32))] + points_t = [torch.from_numpy(np.zeros((1, 2), dtype=np.float32))] + labels_t = [torch.from_numpy(np.ones((1,), dtype=np.float32))] + boxes_t = [torch.from_numpy(np.array([0, 0, 1, 1], dtype=np.float32))] + + # Normalize the Image + image_2d = preprocessing.proprocess(image_2d) # 3xHxW + + return { + "image": image_2d, # HxWx3 + "masks": masks_t, # list[H x W] float32 in {0,1} + "points": points_t, # list[#p x 2] float32 (xy) + "labels": labels_t, # list[#p] all ones + "boxes": boxes_t, # list[4] (x0,y0,x1,y1) + } \ No newline at end of file diff --git a/saber/finetune/helper.py b/saber/finetune/helper.py new file mode 100644 index 0000000..e58e5d0 --- /dev/null +++ b/saber/finetune/helper.py @@ -0,0 +1,384 @@ +from saber.visualization.classifier import get_colors, add_masks +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import cv2, torch, csv, os +import numpy as np + +def mask_to_box(mask: np.ndarray) -> np.ndarray | None: + """xyxy box from a binary mask (H,W) in {0,1}.""" + ys, xs = np.where(mask > 0) + if xs.size == 0: + return None + return np.array([xs.min(), ys.min(), xs.max(), ys.max()], dtype=np.float32) + +def sample_positive_points(mask: np.ndarray, k: int = 1) -> np.ndarray: + """Sample k (x,y) positives uniformly from mask pixels.""" + ys, xs = np.where(mask > 0) + if xs.size == 0: + return np.zeros((0, 2), dtype=np.float32) + idx = np.random.randint(0, xs.size, size=k) + return 
np.stack([xs[idx], ys[idx]], axis=1).astype(np.float32) + +def instances_from_grid(grid_points: np.ndarray, segmentation: np.ndarray) -> list[int]: + """Return unique non-zero instance IDs at the given (x,y) grid sample points.""" + h, w = segmentation.shape + xs = np.clip(grid_points[:, 0].astype(np.int32), 0, w - 1) + ys = np.clip(grid_points[:, 1].astype(np.int32), 0, h - 1) + ids = segmentation[ys, xs] + ids = ids[ids > 0] + return np.unique(ids).astype(int).tolist() + +# helper: split an instance id into connected components +def components_for_id(seg, iid: int, min_pixels: int): + mask_u8 = (seg == iid).astype(np.uint8) # 0/1 + num, lbl = cv2.connectedComponents(mask_u8) + comps = [] + for cid in range(1, num): # skip background=0 + comp = (lbl == cid).astype(np.float32) # HxW {0,1} + if comp.sum() >= min_pixels: + comps.append(comp) + return comps + + +def collate_autoseg(batch, max_res: int = 1024): + """ + Aspect-preserving resize to fit within max_res, then pad each sample + to the batch's (H_max, W_max). Keep ragged structures (points/labels/boxes) + as lists per instance. + """ + processed = [_resize_one(sample, max_res) for sample in batch] + + # Common padded size for this batch + H_max = max(s["image"].shape[0] for s in processed) + W_max = max(s["image"].shape[1] for s in processed) + + for s in processed: + h, w = s["image"].shape[:2] + + # pad image (top-left anchoring) + pad_img = np.zeros((H_max, W_max, 3), dtype=s["image"].dtype) + pad_img[:h, :w] = s["image"] + s["image"] = pad_img + + # pad masks + padded_masks = [] + for m in s["masks"]: + pm = np.zeros((H_max, W_max), dtype=np.uint8) + pm[:h, :w] = m + padded_masks.append(pm) + s["masks"] = padded_masks + + # NOTE: points/boxes coords don't change with top-left padding + + # Return exactly what your trainer expects: lists of per-sample items, + # and for ragged things, lists of per-instance tensors. 
+ return { + "images": [s["image"] for s in processed], # list of HxWx3 uint8 (predictor handles numpy) + "masks": [torch.from_numpy(np.stack(s["masks"])) for s in processed], # list of [M,H,W] + "points": [ [torch.from_numpy(p) for p in s["points"]] for s in processed], # list of list[[Pi,2]] + "labels": [ [torch.as_tensor(l) for l in s["labels"]] for s in processed], # list of list[[Pi]] + "boxes": [ [torch.from_numpy(b) for b in s["boxes"]] for s in processed], # list of list[[4]] + } + +def _resize_one(s, max_res: int): + """ + Resize one sample with aspect preserved. Scale coords per instance. + Keeps ragged structures as lists. + """ + img = s["image"] + masks = s["masks"] # list of [H,W] + points = s["points"] # list of [Pi,2] + labels = s["labels"] # list of [Pi] + boxes = s["boxes"] # list of [4] + + H, W = img.shape[:2] + r = min(max_res / max(H, W), 1.0) # cap longest side; don't upscale + newH, newW = int(round(H * r)), int(round(W * r)) + + if (newH, newW) != (H, W): + img = cv2.resize(img, (newW, newH), interpolation=cv2.INTER_LINEAR) + masks = [cv2.resize(m.astype(np.uint8), (newW, newH), interpolation=cv2.INTER_NEAREST) for m in masks] + pts = [np.asarray(p, dtype=np.float32) * r for p in points] # scale each instance + bxs = [np.asarray(b, dtype=np.float32) * r for b in boxes] + else: + pts = [np.asarray(p, dtype=np.float32) for p in points] + bxs = [np.asarray(b, dtype=np.float32) for b in boxes] + + # labels stay as-is per instance + lbls = [np.asarray(l) for l in labels] + + return { + "image": img, + "masks": masks, + "points": pts, + "labels": lbls, + "boxes": bxs, + } + +def _to_numpy_mask_stack(masks): + """ + Accepts list/tuple of tensors or np arrays shaped [H,W]; + returns np.uint8 array [N,H,W] with values {0,1}. 
+ """ + if isinstance(masks, np.ndarray) and masks.ndim == 3: + arr = masks + else: + arr = np.stack([m.detach().cpu().numpy() if hasattr(m, "detach") else np.asarray(m) + for m in masks], axis=0) + # binarize & cast + arr = (arr > 0).astype(np.uint8) + return arr + +def visualize_item_with_points(image, masks, points, boxes=None, + title=None, point_size=24): + """ + Show a single image with all component masks (colored), + all positive points (color-matched to masks), and optional boxes. + + Args + ---- + image : np.ndarray (H,W) or (H,W,3) (uint8 or float) + masks : list[np.ndarray or torch.Tensor] each [H,W] in {0,1} + points: list[np.ndarray] each [P,2] as (x,y) + boxes : list[np.ndarray] each [4] xyxy (optional) + """ + # normalize inputs + mstack = _to_numpy_mask_stack(masks) # [N,H,W] in {0,1} + + # set up figure + fig, ax = plt.subplots(1, 1, figsize=(5, 5)) + if image.ndim == 2: + ax.imshow(image, cmap="gray", interpolation="nearest") + else: + ax.imshow(image, interpolation="nearest") + + # overlay masks with your helper + add_masks(mstack, ax) # uses get_colors internally + + # color-match points (and boxes) to mask color + colors = get_colors(len(masks)) + for i, pts in enumerate(points): + if pts is None or len(pts) == 0: + continue + color = colors[i % len(colors)] + ax.scatter(pts[:, 0], pts[:, 1], s=point_size, + c=[(color[0], color[1], color[2], 1.0)], + edgecolors="k", linewidths=0.5) + + if boxes is not None: + for i, b in enumerate(boxes): + if b is None: + continue + color = colors[i % len(colors)] + x0, y0, x1, y1 = map(float, b) + rect = patches.Rectangle( + (x0, y0), x1 - x0, y1 - y0, + linewidth=2, + edgecolor=(color[0], color[1], color[2], 1.0), + facecolor="none" + ) + ax.add_patch(rect) + + if title: + ax.set_title(title) + ax.axis("off") + plt.tight_layout() + +def save_training_log(results, outdir="results", metric_keys=["ABIoU"]): + + # CSV (epoch-aligned, pad with blanks if needed) + path = os.path.join(outdir, "metrics.csv") + 
is_new = not os.path.exists(path) + + with open(path, "a", newline="") as f: + writer = csv.DictWriter(f, + fieldnames=["epoch", "lr_mask", "lr_prompt", "train_loss", "val_loss", + *metric_keys, "train_iou", "train_dice", "train_mask", "val_iou", + "val_dice", "val_mask"]) + if is_new: + writer.writeheader() + writer.writerow({ + "epoch": int(results['epoch']), + "lr_mask": f"{results['lr_mask']:.1e}", + "lr_prompt": f"{results['lr_prompt']:.1e}", + "train_loss": float(results['train']['loss_total']), + "val_loss": float(results['val']['loss_total']), + **{k: float(results['val'][k]) for k in metric_keys}, + "train_iou": float(results['train']['loss_iou']), + "train_dice": float(results['train']['loss_dice']), + "train_mask": float(results['train']['loss_mask']), + "val_iou": float(results['val']['loss_iou']), + "val_dice": float(results['val']['loss_dice']), + "val_mask": float(results['val']['loss_mask']), + }) + +######################################################################################## + +def _orig_hw_tuple(pred): + ohw = pred._orig_hw + # Cases seen in SAM/SAM2 wrappers: (H,W), [H,W], [(H,W)], possibly numpy ints + if isinstance(ohw, list): + # list of tuples -> pick the last tuple + if len(ohw) and isinstance(ohw[-1], (list, tuple)): + h, w = ohw[-1] + else: + # flat [H, W] + h, w = ohw[0], ohw[1] + elif isinstance(ohw, tuple): + h, w = ohw + else: + # e.g., numpy array-like + h, w = ohw[0], ohw[1] + return (int(h), int(w)) + +@torch.no_grad() +def infer_on_single_image( + predictor, + image, + inst_points, # list[Tensor(Pi,2)] + inst_labels, # list[Tensor(Pi)] + inst_masks=None, # optional list[H,W] + inst_boxes=None, # list[Tensor(4)] or None + use_boxes=True, + predict_multimask=False, + device="cuda", +): + # 1) Encode image once + predictor.set_image(image) + + # 2) Normalize cached features to [1,C,H',W'] + def _to_batched_4d(x): + if isinstance(x, (list, tuple)): + x = x[0] + if x.dim() == 3: x = x.unsqueeze(0) + return x.to(device) + + 
image_embeddings = _to_batched_4d(predictor._features["image_embed"]) # [1,C,H′,W′] + high_res_features = [_to_batched_4d(lvl) for lvl in predictor._features["high_res_feats"]] + + # 3) Pack prompts to (N,P,2) / (N,P) + pts_all = [torch.as_tensor(p, device=device, dtype=torch.float32) for p in inst_points] + lbl_all = [torch.as_tensor(l, device=device, dtype=torch.float32) for l in inst_labels] + N = len(pts_all) + if N == 0: + return None, None, None + + P = max(p.shape[0] for p in pts_all) + pts_pad = torch.zeros((N, P, 2), device=device) + lbl_pad = torch.full((N, P), -1.0, device=device) + for i,(p,l) in enumerate(zip(pts_all, lbl_all)): + pts_pad[i, :p.shape[0]] = p + lbl_pad[i, :l.shape[0]] = l + + boxes = None + if use_boxes and inst_boxes: + boxes = torch.stack([torch.as_tensor(bx, device=device, dtype=torch.float32) for bx in inst_boxes], dim=0) + + # 4) Prompt encoding + _, unnorm_coords, labels, _ = predictor._prep_prompts( + pts_pad, lbl_pad, box=boxes, mask_logits=None, normalize_coords=True + ) + sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder( + points=(unnorm_coords, labels), + boxes=boxes if use_boxes else None, + masks=None, + ) + + # 5) Decode with repeat_image=True ✅ no manual feature tiling + low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder( + image_embeddings=image_embeddings, # [1,C,H′,W′] + image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, # [N, ...] + dense_prompt_embeddings=dense_embeddings, # [N, ...] 
+ multimask_output=predict_multimask, + repeat_image=True, # <-- let decoder repeat internally + high_res_features=high_res_features, # each [1,C,H′,W′] + ) + + # 6) Upscale + optional stack GT + def _orig_hw_tuple(pred): + ohw = pred._orig_hw + if isinstance(ohw, list) and len(ohw) and isinstance(ohw[-1], (list, tuple)): + return (int(ohw[-1][0]), int(ohw[-1][1])) + if isinstance(ohw, tuple): + return (int(ohw[0]), int(ohw[1])) + return tuple(int(v) for v in ohw) + + out_hw = _orig_hw_tuple(predictor) # (H, W) e.g., (928, 960) + prd_masks = predictor._transforms.postprocess_masks(low_res_masks, out_hw) # [N,K,H,W] + + gt_masks = None + if inst_masks: + gt_masks = torch.stack([torch.as_tensor(m, device=device).float() for m in inst_masks], dim=0) + + return prd_masks, prd_scores, gt_masks + +@torch.no_grad() +def _binary_iou(a: torch.Tensor, b: torch.Tensor, eps=1e-6) -> torch.Tensor: + # a,b: (N,H,W) boolean/0-1 + inter = (a & b).float().sum(dim=(1,2)) + uni = (a | b).float().sum(dim=(1,2)).clamp_min(eps) + return inter / uni + +@torch.no_grad() +def stability_from_logits(mask_logits: torch.Tensor, delta: float = 0.05) -> torch.Tensor: + """ + mask_logits: (K,H,W) raw logits from the decoder. + We compute IoU between masks thresholded at 0 and +/- delta in logit-space. + Returns: (K,) stability score in [0,1]. + """ + # Note: comparing binary maps at thresholds t1 and t2 is enough; no need to sigmoid. + m_lo = (mask_logits > -delta) # t = -delta + m_hi = (mask_logits > +delta) # t = +delta + # Using IoU between the two gives the 'stability' used by AMG + return _binary_iou(m_lo, m_hi) + +@torch.no_grad() +def dynamic_multimask_predict( + predictor, + image_batch_entry, + points=None, labels=None, box=None, mask_input=None, + stability_thresh: float = 0.98, + stability_delta: float = 0.05, + multimask_k: int = 3, +): + """ + 1) Run single-mask decode (multimask_output=False). + 2) Compute stability; if stable >= thresh, keep single. 
+ 3) Otherwise rerun with multimask_output=True to get K candidates. + + Returns: + prd_masks: (K,H,W) float logits (not sigmoid) **if multimask**, else (1,H,W) + prd_scores: (K,) IoU-head predictions (or (1,)) + aux: dict with {'stability': tensor} + """ + # First pass: single mask + out1 = predictor.predict( + image=image_batch_entry, + point_coords=points, point_labels=labels, + box=box, mask_input=mask_input, + multimask_output=False, # <-- single + return_logits=True, # <-- to compute stability + return_iou_predictions=True, # <-- IoU head (SAM2) + ) + logits1 = out1["masks_logits"][0] # (1,H,W) + iou1 = out1["iou_predictions"][0] # (1,) + stab1 = stability_from_logits(logits1, delta=stability_delta) # (1,) + + if stab1[0].item() >= stability_thresh: + return logits1, iou1, {"stability": stab1} + + # Not stable -> get multimask + outk = predictor.predict( + image=image_batch_entry, + point_coords=points, point_labels=labels, + box=box, mask_input=mask_input, + multimask_output=True, # <-- gated on stability + num_multimask_outputs=multimask_k, # if your API supports it; otherwise ignored + return_logits=True, + return_iou_predictions=True, + ) + logitsk = outk["masks_logits"][0] # (K,H,W) + iouk = outk["iou_predictions"][0] # (K,) + stabk = stability_from_logits(logitsk, delta=stability_delta) # (K,) + return logitsk, iouk, {"stability": stabk} \ No newline at end of file diff --git a/saber/finetune/losses.py b/saber/finetune/losses.py new file mode 100644 index 0000000..d9e1950 --- /dev/null +++ b/saber/finetune/losses.py @@ -0,0 +1,163 @@ +import torch.nn.functional as F +import torch.nn as nn +import torch + +def dice_loss_from_logits(logits, targets, eps=1e-6): + probs = torch.sigmoid(logits) + inter = (probs * targets).sum(dim=(1, 2)) + denom = probs.sum(dim=(1, 2)) + targets.sum(dim=(1, 2)) + return 1 - (2 * inter + eps) / (denom + eps) + +def focal_loss_from_logits(logits, targets, alpha=0.25, gamma=2.0): + # logits, targets: (N,H,W) float with targets 
in {0,1} + ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none") + p = torch.sigmoid(logits) + pt = p * targets + (1 - p) * (1 - targets) + w = alpha * targets + (1 - alpha) * (1 - targets) + return (w * ((1 - pt) ** gamma) * ce).mean(dim=(1, 2)) # -> (N,) + +class MultiMaskIoULoss(nn.Module): + def __init__( + self, + iou_regression: str = "l1", # "l1" or "mse" + supervise_all_iou: bool = False, + all_iou_weight: float = 0.1, + # Weights: match the (dice=20, focal=1, iou=1, class=1) convention + weight_dict: dict = None, + focal_alpha: float = 0.25, + focal_gamma: float = 2.0, + # Objectness head + pred_obj_scores: bool = True, + focal_alpha_obj: float = 0.5, # can be -1 to disable alpha weighting + focal_gamma_obj: float = 0.0, # 0 -> plain BCE on objectness + gate_by_objectness: bool = True, # gate seg/IoU losses when no object + ): + super().__init__() + if weight_dict is None: + weight_dict = {"loss_mask": 1.0, "loss_dice": 20.0, "loss_iou": 1.0, "loss_class": 1.0} + # NOTE: "loss_mask" == focal; "loss_dice" == dice; "loss_class" == objectness + self.focal_weight = weight_dict.get("loss_mask", 1.0) + self.dice_weight = weight_dict.get("loss_dice", 20.0) + self.iou_head_weight = weight_dict.get("loss_iou", 1.0) + self.class_weight = weight_dict.get("loss_class", 1.0) + + self.iou_regression = iou_regression + self.supervise_all_iou = supervise_all_iou + self.all_iou_weight = all_iou_weight + self.focal_alpha = focal_alpha + self.focal_gamma = focal_gamma + + self.pred_obj_scores = pred_obj_scores + self.focal_alpha_obj = focal_alpha_obj + self.focal_gamma_obj = focal_gamma_obj + self.gate_by_objectness = gate_by_objectness + + def _reg_loss(self, pred, target): + return F.mse_loss(pred, target) if self.iou_regression == "mse" else F.l1_loss(pred, target) + + @staticmethod + def _binary_iou(bin_mask, gt_bool, eps=1e-6): + inter = (bin_mask & gt_bool).float().sum(dim=(1, 2)) + union = (bin_mask | gt_bool).float().sum(dim=(1, 
2)).clamp_min(eps) + return inter / union + + def _objectness_targets(self, gt_mask): + # gt_mask: (H,W) bool/0-1 → scalar 0/1 (any positive pixel?) + return (gt_mask.sum() > 0).float() + + def _objectness_loss(self, logits_scalar, target_scalar): + # logits_scalar: shape () or (1,) + # Use focal-like BCE if gamma>0, else plain BCE + if self.focal_gamma_obj == 0.0 and self.focal_alpha_obj < 0: + return F.binary_cross_entropy_with_logits(logits_scalar, target_scalar) + p = torch.sigmoid(logits_scalar) + ce = F.binary_cross_entropy_with_logits(logits_scalar, target_scalar, reduction="none") + pt = p * target_scalar + (1 - p) * (1 - target_scalar) + if self.focal_alpha_obj >= 0: + w = self.focal_alpha_obj * target_scalar + (1 - self.focal_alpha_obj) * (1 - target_scalar) + else: + w = 1.0 + return (w * ((1 - pt) ** self.focal_gamma_obj) * ce).mean() + + def forward(self, prd_masks_logits, prd_iou_scores, gt_masks, object_score_logits=None): + """ + prd_masks_logits: (N,K,H,W) or (K,H,W) + prd_iou_scores: (N,K) or (K,) + gt_masks: (N,H,W) or (H,W) + object_score_logits: (N,) or (N,1) or scalar per instance (optional) + """ + # ---- normalize shapes to batched form ---- + if prd_masks_logits.dim() == 3: # (K,H,W) -> (1,K,H,W) + prd_masks_logits = prd_masks_logits.unsqueeze(0) + if prd_iou_scores.dim() == 1: # (K,) -> (1,K) + prd_iou_scores = prd_iou_scores.unsqueeze(0) + if gt_masks.dim() == 2: # (H,W) -> (1,H,W) + gt_masks = gt_masks.unsqueeze(0) + + N, K, H, W = prd_masks_logits.shape + gt_bool = gt_masks.bool() # (N,H,W) + gt_float = gt_masks.float() # (N,H,W) + + # ---- IoU per slot (vectorized) ---- + bin_masks = (prd_masks_logits > 0) # (N,K,H,W) + gt_bool_b = gt_bool.unsqueeze(1) # (N,1,H,W) <-- key fix + inter = (bin_masks & gt_bool_b).float().sum(dim=(2,3)) # (N,K) + union = (bin_masks | gt_bool_b).float().sum(dim=(2,3)).clamp_min(1e-6) # (N,K) + true_iou_per_k = inter / union # (N,K) + + # ---- pick best slot per instance ---- + best_ix = 
torch.argmax(true_iou_per_k, dim=1) # (N,) + idx_4d = best_ix.view(N,1,1,1).expand(N,1,H,W) + best_logits = prd_masks_logits.gather(1, idx_4d).squeeze(1) # (N,H,W) + + # ---- segmentation losses on best slot ---- + dice_per = dice_loss_from_logits(best_logits, gt_float) # (N,) + focal_per = focal_loss_from_logits( + best_logits, gt_float, + alpha=self.focal_alpha, gamma=self.focal_gamma + ) # (N,) + seg_loss = self.dice_weight * dice_per.mean() + self.focal_weight * focal_per.mean() + + # ---- IoU head regression ---- + iou_pred_best = prd_iou_scores.gather(1, best_ix.view(N,1)).squeeze(1) # (N,) + iou_target_best = true_iou_per_k.gather(1, best_ix.view(N,1)).squeeze(1) # (N,) + + # IMPORTANT: prd_iou_scores are already probs in [0,1]; do NOT apply sigmoid here. + # Regress directly (L1 or MSE) to the true IoU. + iou_reg_main = self._reg_loss(iou_pred_best, iou_target_best) + + total = seg_loss + self.iou_head_weight * iou_reg_main + + # ---- optional gentle supervision over all K ---- + iou_reg_all = None + if self.supervise_all_iou and K > 1 and self.all_iou_weight > 0: + # SAME: no sigmoid; train all K scores directly toward their per-slot IoUs + iou_reg_all = self._reg_loss(prd_iou_scores, true_iou_per_k) + total = total + self.all_iou_weight * iou_reg_all + + # ---- objectness (per instance) ---- + obj_loss = None + if self.pred_obj_scores: + assert object_score_logits is not None, "object_score_logits required when pred_obj_scores=True" + obj_logits = object_score_logits.view(N, -1).mean(dim=1) # (N,) + obj_target = (gt_bool.view(N, -1).sum(dim=1) > 0).float() # (N,) + obj_loss = self._objectness_loss(obj_logits, obj_target) + total = total + self.class_weight * obj_loss + + if self.gate_by_objectness: + has_obj = (obj_target > 0.5).float() # (N,) + gate = has_obj.mean().clamp_min(1e-6) + total = self.class_weight * obj_loss + gate * (seg_loss + self.iou_head_weight * iou_reg_main) + if iou_reg_all is not None: + total = total + gate * (self.all_iou_weight * 
iou_reg_all) + + return { + "loss_total": total, + "loss_seg": seg_loss.detach(), + "loss_iou": iou_reg_main.detach(), + "loss_iou_all": (iou_reg_all.detach() if iou_reg_all is not None else None), + "loss_class": (obj_loss.detach() if obj_loss is not None else None), + "loss_mask": focal_per.mean().detach(), + "loss_dice": dice_per.mean().detach(), + "true_iou_best": iou_target_best.mean().detach(), + } \ No newline at end of file diff --git a/saber/finetune/metrics.py b/saber/finetune/metrics.py new file mode 100644 index 0000000..2a3d906 --- /dev/null +++ b/saber/finetune/metrics.py @@ -0,0 +1,249 @@ +from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator +from typing import Dict, Any, List, Tuple +import torch.nn.functional as F +import numpy as np +import torch + +# Subset of IoU thresholds, as requested: +AR_THRESHOLDS = np.array([0.50, 0.65, 0.75, 0.85], dtype=np.float32) + +# ------------------------ Decoder-side helpers ------------------------ + +def _binary_iou(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + """ + Fast IoU for binary masks (boolean or {0,1} tensors). + Shapes: a,b: (H, W) + """ + inter = (a & b).float().sum() + uni = (a | b).float().sum().clamp_min(eps) + return inter / uni + +@torch.no_grad() +def decoder_prompt_miou(prd_masks: torch.Tensor, gt_masks: torch.Tensor) -> float: + """ + Best-of-K decoder prompt mIoU. + Args: + prd_masks: [N, K, H, W] LOGITS from the decoder + gt_masks: [N, H, W] binary {0,1} + Returns: + mean over N of (max IoU over K), using SAM2's thresholding rule (logits > 0). 
+ """ + N, K, H, W = prd_masks.shape + # SAM2 convention: threshold decoder outputs by logits > 0 + pred_bin = (prd_masks > 0) # bool, [N,K,H,W] + ious = [] + for n in range(N): + gt = gt_masks[n].bool() + if gt.sum() == 0: + continue # skip empty GT + best = torch.stack([_binary_iou(pred_bin[n, k], gt) for k in range(K)], dim=0).max() + ious.append(best) + if len(ious) == 0: + return float("nan") + return float(torch.stack(ious).mean().item()) + +@torch.no_grad() +def iou_head_calibration_from_decoder( + prd_masks: torch.Tensor, + prd_scores: torch.Tensor, + gt_masks: torch.Tensor, + num_bins: int = 15, +) -> Dict[str, Any]: + """ + Compare predicted IoU (sigmoid(prd_scores)) vs true IoU (from logits>0 masks). + Args: + prd_masks: [N,K,H,W] LOGITS + prd_scores: [N,K] raw IoU logits + gt_masks: [N,H,W] + Returns: + dict with calibration MAE, Brier, ECE, and a per-bin table (for diagnostics). + """ + N, K, H, W = prd_masks.shape + preds, trues = [], [] + pred_bin = (prd_masks > 0) # bool + + for n in range(N): + gt = gt_masks[n].bool() + if gt.sum() == 0: + continue + # True IoU per K proposal + true_iou_k = torch.stack([_binary_iou(pred_bin[n, k], gt) for k in range(K)], dim=0) # [K] + pred_iou_k = prd_scores[n].sigmoid().clamp(0, 1) # [K] + trues.append(true_iou_k) + preds.append(pred_iou_k) + + if len(preds) == 0: + return {"mae": float("nan"), "brier": float("nan"), "ece": float("nan"), "table": []} + + preds = torch.cat(preds) # [N'*K] + trues = torch.cat(trues) # [N'*K] + + mae = torch.abs(preds - trues).mean().item() + brier = torch.mean((preds - trues) ** 2).item() + + # Expected Calibration Error (ECE) over uniform bins + bins = np.linspace(0.0, 1.0, num_bins + 1) + pred_np, true_np = preds.cpu().numpy(), trues.cpu().numpy() + idx = np.clip(np.digitize(pred_np, bins, right=True) - 1, 0, num_bins - 1) + + ece = 0.0 + table = [] + total = len(pred_np) + for b in range(num_bins): + m = (idx == b) + n_b = int(m.sum()) + if n_b == 0: + table.append({"bin": 
f"[{bins[b]:.2f},{bins[b+1]:.2f})", "count": 0, + "mean_pred": None, "mean_true": None, "gap": None}) + continue + mp = float(pred_np[m].mean()) + mt = float(true_np[m].mean()) + gap = abs(mp - mt) + ece += (n_b / total) * gap + table.append({"bin": f"[{bins[b]:.2f},{bins[b+1]:.2f})", "count": n_b, + "mean_pred": round(mp, 4), "mean_true": round(mt, 4), "gap": round(gap, 4)}) + + return {"mae": mae, "brier": float(brier), "ece": float(ece), "table": table} + +# ------------------------ AMG proposal metrics ------------------------ + +def _iou(a: np.ndarray, b: np.ndarray) -> float: + """IoU for boolean numpy masks (H,W).""" + inter = np.logical_and(a, b).sum() + union = np.logical_or(a, b).sum() + return float(inter) / max(1.0, float(union)) + +def _iou_matrix(preds: List[np.ndarray], gts: List[np.ndarray]) -> np.ndarray: + """ + Build P x G IoU matrix for numpy boolean masks. + preds: list of predicted masks (H,W) + gts: list of gt masks (H,W) + """ + if len(preds) == 0 or len(gts) == 0: + return np.zeros((len(preds), len(gts)), dtype=np.float32) + M = np.zeros((len(preds), len(gts)), dtype=np.float32) + for i, p in enumerate(preds): + for j, g in enumerate(gts): + M[i, j] = _iou(p, g) + return M + +def _greedy_match(M: np.ndarray, tau: float) -> Tuple[int, int, int]: + """ + Greedy bipartite matching by IoU descending with threshold tau. 
+ Returns: + TP, FP, FN + """ + P, G = M.shape + used_p, used_g, matches = set(), set(), [] + pairs = [(i, j, M[i, j]) for i in range(P) for j in range(G)] + pairs.sort(key=lambda x: x[2], reverse=True) + for i, j, iou in pairs: + if iou < tau: + break + if i in used_p or j in used_g: + continue + used_p.add(i); used_g.add(j); matches.append((i, j)) + TP = len(matches); FP = P - TP; FN = G - TP + return TP, FP, FN + +def average_recall_amg( + amg_outputs: List[List[dict]], + gt_masks_list: List[List[torch.Tensor]], + iou_thresholds: np.ndarray = AR_THRESHOLDS, + max_proposals: int = None, +) -> Dict[str, Any]: + """ + Average Recall across IoU thresholds. + Args: + amg_outputs: list over images of list[mask_dict]; each dict has 'segmentation' (np.bool array) and 'predicted_iou' + gt_masks_list: list over images of list[tensor HxW] for ground-truth instances + iou_thresholds: numpy array of IoU taus to average over + max_proposals: cap #proposals per image (after ranking by predicted_iou) for speed + Returns: + {"AR": scalar, "per_tau_recall": {tau: recall_tau}} + """ + recalls = [] + for tau in iou_thresholds: + tp = fn = 0 + for img_masks, gts in zip(amg_outputs, gt_masks_list): + # rank by predicted_iou then cap + if max_proposals is not None: + img_masks = sorted(img_masks, key=lambda d: d.get('predicted_iou', 0.0), reverse=True)[:max_proposals] + preds = [m['segmentation'].astype(bool) for m in img_masks] + gts_np = [g.cpu().numpy().astype(bool) for g in gts] + M = _iou_matrix(preds, gts_np) + tpi, _, fni = _greedy_match(M, float(tau)) + tp += tpi; fn += fni + denom = tp + fn + recalls.append(tp / denom if denom > 0 else np.nan) + + per_tau = {float(t): (None if np.isnan(r) else float(r)) for t, r in zip(iou_thresholds, recalls)} + return {"AR": float(np.nanmean(recalls)), "per_tau_recall": per_tau} + +def recall_at_k_amg( + amg_outputs: List[List[dict]], + gt_masks_list: List[List[torch.Tensor]], + ks: Tuple[int, ...] 
= (10, 50, 100), + iou_thresh: float = 0.5, +) -> Dict[str, Any]: + """ + Recall@K at a fixed IoU threshold (default 0.5). + """ + out = {} + for K in ks: + tp = fn = 0 + for img_masks, gts in zip(amg_outputs, gt_masks_list): + sel = sorted(img_masks, key=lambda d: d.get('predicted_iou', 0.0), reverse=True)[:K] + preds = [m['segmentation'].astype(bool) for m in sel] + gts_np = [g.cpu().numpy().astype(bool) for g in gts] + M = _iou_matrix(preds, gts_np) + tpi, _, fni = _greedy_match(M, iou_thresh) + tp += tpi; fn += fni + denom = tp + fn + out[K] = tp / denom if denom > 0 else float('nan') + return {"Recall@K": out, "iou_thresh": iou_thresh} + +# ------------------------ Wrapper for validation loop ------------------------ + +@torch.no_grad() +def sam2_metrics(batch: Dict[str, Any], outputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], amg) -> Dict[str, Any]: + """ + Compute a SAM2-style metric bundle for a single validation batch. + Args: + batch: {"images": list[np.ndarray or torch tensor], "masks": list[list[torch.Tensor HxW]]} + outputs: (prd_masks, prd_scores, gt_masks) where + prd_masks: [N,K,H,W] logits, prd_scores: [N,K] logits, gt_masks: [N,H,W] {0,1} + amg: a SAM2AutomaticMaskGenerator instance + Returns: + dict with prompt_mIoU, IoU calibration (mae/brier/ece + table), + AR averaged over {0.50,0.65,0.75,0.85}, per-threshold recalls, + and Recall@{10,50,100}. Includes "num_images" for weighted averaging. 
+ """ + prd_masks, prd_scores, gt_masks = outputs[:3] + + # Decoder-side metrics (cheap; no extra forward) + pm = decoder_prompt_miou(prd_masks, gt_masks) + cal = iou_head_calibration_from_decoder(prd_masks, prd_scores, gt_masks, num_bins=15) + + # AMG proposals (dominates runtime; run once, reuse for AR and R@K) + all_amg = [amg.generate(img) for img in batch["images"]] + all_gt = batch["masks"] + + ar = average_recall_amg(all_amg, all_gt, iou_thresholds=AR_THRESHOLDS, max_proposals=200) + rK = recall_at_k_amg(all_amg, all_gt, ks=(10, 50, 100), iou_thresh=0.5) + + return { + "prompt_miou": float(pm) if pm == pm else float("nan"), + "cal_mae": cal["mae"], + "cal_brier": cal["brier"], + "cal_ece": cal["ece"], + "cal_tables": cal["table"], # keep for inspection; don’t reduce across ranks + "AR": ar["AR"], + "num_images": len(batch["images"]), + "R@10": rK["Recall@K"][10], + "R@50": rK["Recall@K"][50], + "R@100": rK["Recall@K"][100], + } + # # "per_tau_recall": ar["per_tau_recall"], # keep for plotting; don’t reduce as a scalar + # "R@10": rK["Recall@K"][10], + # "R@50": rK["Recall@K"][50], \ No newline at end of file diff --git a/saber/finetune/prep.py b/saber/finetune/prep.py new file mode 100644 index 0000000..4d58ab3 --- /dev/null +++ b/saber/finetune/prep.py @@ -0,0 +1,318 @@ +""" +Convert copick data to 3D zarr format with simple JSON segmentation mapping. +Saves full 3D volumes and creates a simple ID-to-organelle mapping sorted by volume. 
def process_run_3d_simple(
    run,
    voxel_spacing: float,
    tomo_alg: str,
    organelle_names: List[str],
    min_component_volume: int = 100,
    user_id: Optional[str] = None,
    values: Optional[Dict[str, int]] = None,
) -> Tuple[np.ndarray, np.ndarray, Dict[str, str]]:
    """
    Process a single run in full 3D with simple volume-based sorting.

    Parameters:
        run: Copick run object providing the tomogram and segmentations.
        voxel_spacing: Voxel spacing (Å) used to select the tomogram.
        tomo_alg: Reconstruction algorithm name used to select the tomogram.
        organelle_names: Segmentation names to rasterize into the label volume.
        min_component_volume: Minimum connected-component size (voxels) kept
            when splitting a segmentation into instances.
        user_id: Optional user ID passed through when fetching segmentations.
        values: Optional fixed organelle-name -> label-value mapping. When
            provided, no connected-component splitting is done and each
            organelle keeps its predefined value. (Changed from a mutable
            `{}` default to None to avoid the shared-mutable-default pitfall.)

    Returns:
        volume_3d: Full 3D tomogram.
        seg_3d: 3D segmentation with unique labels sorted by volume
            (smallest first, so small objects are on top in the GUI).
        id_to_organelle: Dictionary mapping mask label values (as str) to
            organelle names.
    """
    # Normalize the optional mapping; never mutate a shared default.
    values = values or {}

    click.echo(" Loading tomogram...")

    # Get tomogram data
    vs = run.get_voxel_spacing(voxel_spacing)
    tomograms = vs.get_tomograms(tomo_alg)

    if not tomograms:
        raise ValueError(f"No tomograms found for run {run.name}")

    volume_3d = tomograms[0].numpy()
    click.echo(f" Tomogram shape: {volume_3d.shape}")

    components = []  # (organelle_name, temporary_label, voxel_count)
    offset = 1       # next free temporary label value (sequential mode only)
    temp_seg_3d = np.zeros_like(volume_3d, dtype=np.uint32)

    for organelle_name in organelle_names:
        try:
            click.echo(f" Processing {organelle_name}...")
            seg = readers.segmentation(run, voxel_spacing, organelle_name, user_id=user_id)

            if seg is None:
                click.echo(" No segmentation found")
                continue
            elif seg.shape != volume_3d.shape:
                print(f"[Warning, {run.name}] Segmentation shape {seg.shape} does not match tomogram shape {volume_3d.shape}")
                continue
            elif values:
                # Fixed-value mode: paint the whole organelle with its
                # predefined label; no instance splitting.
                offset = values[organelle_name]
                temp_seg_3d[seg > 0.5] = offset
                components.append((organelle_name, offset, np.sum(seg > 0.5)))
            else:
                # Instance mode: convert to binary, split into 26-connected
                # components (connectivity=3 in 3D), and keep only components
                # above the minimum volume.
                binary_mask = (seg > 0.5).astype(np.uint8)
                labeled_mask = label(binary_mask, connectivity=3)

                # Single pass over foreground voxels for per-label volumes
                unique_labels, counts = np.unique(labeled_mask[labeled_mask > 0], return_counts=True)

                for label_val, vol in zip(unique_labels, counts):
                    if vol >= min_component_volume:
                        temp_seg_3d[labeled_mask == label_val] = offset
                        components.append((organelle_name, offset, vol))
                        offset += 1
        except Exception as e:
            click.echo(f" Error processing {organelle_name}: {e}")

    # Sort by volume (smallest first, so small objects are on top in GUI)
    components.sort(key=lambda x: x[2])

    # Create final segmentation with remapped labels and the mapping dictionary
    seg_3d = np.zeros_like(temp_seg_3d, dtype=np.uint16)
    id_to_organelle: Dict[str, str] = {}
    for new_label, (organelle_name, old_label, _volume) in enumerate(components, start=1):
        # Only remap if we have predefined values, else keep sequential values
        if values:
            new_label = old_label
        seg_3d[temp_seg_3d == old_label] = new_label
        id_to_organelle[str(new_label)] = organelle_name

    return volume_3d, seg_3d, id_to_organelle
+ """ + + # Initialize copick + root = copick.from_file(config_path) + + # Get organelle names (non-particle objects) + organelle_names = [obj.name for obj in root.pickable_objects if not obj.is_particle] + + # Optional: filter out specific organelles + organelle_names = [x for x in organelle_names if "membrane" not in x] + click.echo(f"Found organelles: {organelle_names}") + + # Prepare organelle values if keeping labels + if keep_labels: + organelle_values = {obj.name: obj.label for obj in root.pickable_objects if not obj.is_particle} + else: + organelle_values = {} + + # Set default JSON output path + if output_json_path is None: + output_json_path = output_zarr_path.replace(".zarr", "_mapping.json") + + # Initialize zarr store + compressor = zarr.Blosc(cname="zstd", clevel=2) + store = zarr.DirectoryStore(output_zarr_path) + zroot = zarr.group(store=store, overwrite=True) + + # Master mapping dictionary + master_mapping: Dict[str, Dict[str, str]] = {} + + # Determine which runs to process + runs_to_process = specific_runs.split(",") if specific_runs else [run.name for run in root.runs] + + for run_name in tqdm(runs_to_process, desc="Processing runs"): + click.echo(f"\nProcessing run: {run_name}") + run = root.get_run(run_name) + + try: + # Process the 3D data + volume_3d, seg_3d, id_to_organelle = process_run_3d_simple( + run=run, + voxel_spacing=voxel_spacing, + tomo_alg=tomo_alg, + organelle_names=organelle_names, + min_component_volume=min_component_volume, + user_id=user_id, + values = organelle_values + ) + + # Create zarr group for this run + run_group = zroot.create_group(run_name) + + # Save 3D volume + voxel_spacing_nm = voxel_spacing / 10 # Convert to nm + run_group.create_dataset( + "0", + data=volume_3d, + chunks=(1, volume_3d.shape[1], volume_3d.shape[2]), # Chunk by slices + compressor=compressor, + dtype=volume_3d.dtype, + ) + add_attributes(run_group, voxel_spacing_nm, True, voxel_spacing_nm) + + # Save 3D segmentation/label volume + label_group 
def load_3d_zarr_data(zarr_path: str, run_name: str) -> Tuple[np.ndarray, np.ndarray, Dict]:
    """
    Load the 3D tomogram and label volume for one run from a zarr store.

    Parameters:
        zarr_path: Path to the zarr directory written by the converter.
        run_name: Name of the run group to load.

    Returns:
        volume: 3D tomogram.
        labels: 3D label volume.
        id_to_organelle: Mapping of label values to organelle names. Currently
            always empty — the per-run mappings live in the external JSON file
            written by the converter, not in the zarr attrs.

    Raises:
        ValueError: If the requested run is not present in the store.
    """
    zroot = zarr.group(store=zarr.DirectoryStore(zarr_path), mode="r")

    if run_name not in zroot:
        raise ValueError(f"Run {run_name} not found in zarr")

    run_group = zroot[run_name]

    # Layout mirrors the writer above: volume under '0', labels under 'labels/0'.
    volume = run_group["0"][:]
    labels = run_group["labels"]["0"][:]

    # NOTE(review): mappings are stored in the external JSON; if you want them
    # inside the zarr instead, set/read per-run attrs here.
    return volume, labels, {}
Provide a comma-separated list of runIDs.", +) +@click.option( + "--min-component-volume", + type=int, + default=1e3, + help="Minimum connected-component volume (in voxels).", +) +@click.option( + "--user-id", + type=str, + default=None, + help="UserID for accessing segmentation.", +) +@click.option( + "--keep-labels", + type=bool, + default=False, + help="Save Segmentations with Values Defined from the Copick Configuration File", +) +def main( + config_path: str, + output_zarr_path: str, + output_json_path: Optional[str], + voxel_spacing: float, + tomo_alg: str, + specific_runs: str, + min_component_volume: int, + user_id: Optional[str], + keep_labels: bool, + ): + + """Convert copick data to 3D zarr format with JSON segmentation mapping.""" + convert_copick_to_3d_zarr( + config_path=config_path, + output_zarr_path=output_zarr_path, + output_json_path=output_json_path, + voxel_spacing=voxel_spacing, + tomo_alg=tomo_alg, + specific_runs=specific_runs, + min_component_volume=min_component_volume, + keep_labels=keep_labels, user_id=user_id, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/saber/finetune/train.py b/saber/finetune/train.py new file mode 100644 index 0000000..a73699f --- /dev/null +++ b/saber/finetune/train.py @@ -0,0 +1,78 @@ +from saber.classifier.datasets.augment import get_finetune_transforms +from sam2.sam2_image_predictor import SAM2ImagePredictor +from saber.finetune.trainer import SAM2FinetuneTrainer +from saber.finetune.dataset import AutoMaskDataset +from saber.finetune.helper import collate_autoseg +from saber.utils.slurm_submit import sam2_inputs +from torch.utils.data import DataLoader +from sam2.build_sam import build_sam2 +from saber import pretrained_weights +from saber.utils import io +import click, os + +def finetune_sam2( + tomo_train: str = None, + fib_train: str = None, + tomo_val: str = None, + fib_val: str = None, + sam2_cfg: str = 'base', + num_epochs: int = 1000, + batch_size: int = 16): + 
""" + Finetune SAM2 on tomograms and FIBs + """ + + # Determine device + (cfg, checkpoint) = pretrained_weights.get_sam2_checkpoint(sam2_cfg) + sam2_model = build_sam2(cfg, checkpoint, device='cuda', postprocess_mask=False) + predictor = SAM2ImagePredictor(sam2_model) + + # Option 1 : Train the Mask Decoder and Prompt Encoder + predictor.model.sam_mask_decoder.train(True) + predictor.model.sam_prompt_encoder.train(True) + + # Load data loaders + train_loader = DataLoader( AutoMaskDataset( + tomo_train, fib_train, + transform=get_finetune_transforms()), + batch_size=batch_size, shuffle=True, + num_workers=4, pin_memory=True, collate_fn=collate_autoseg + ) + + val_loader = DataLoader( AutoMaskDataset( + tomo_val, fib_val, shuffle=False ), + num_workers=4, pin_memory=True, collate_fn=collate_autoseg, + batch_size=batch_size, shuffle=False ) if (tomo_val or fib_val) else train_loader + + # Initialize trainer and train + trainer = SAM2FinetuneTrainer( predictor, train_loader, val_loader ) + # trainer.train( num_epochs, best_metric='AR' ) + trainer.train( num_epochs ) + +@click.command(context_settings={"show_default": True}, name='run') +@sam2_inputs +@click.option("--fib-train", type=str, help="Path to train Zarr") +@click.option("--fib-val", type=str, help="Path to val Zarr") +@click.option("--tomo-train", type=str, help="Path to train Zarr") +@click.option("--tomo-val", type=str, help="Path to val Zarr") +@click.option("--epochs", type=int, default=1000, help="Number of epochs to train for") +@click.option('--batch-size', type=int, default=16, help="Batch size") +def finetune(sam2_cfg: str, epochs: int, fib_train: str, fib_val: str, tomo_train: str, tomo_val: str, batch_size: int): + """ + Finetune SAM2 on 3D Volumes. Images from input tomograms and fibs are generated with slabs and slices, respectively. 
+ """ + + print("--------------------------------") + print( + f"Fine Tuning SAM2 on {fib_train} and {fib_val} and {tomo_train} and {tomo_val} for {epochs} epochs" + ) + print(f"Using SAM2 Config: {sam2_cfg}") + print(f"Tomo Train Zarr: {tomo_train}") + print(f"Tomo Val Zarr: {tomo_val}") + print(f"Fib Train Zarr: {fib_train}") + print(f"Fib Val Zarr: {fib_val}") + print(f"Using Number of Epochs: {epochs}") + print(f"Using Batch Size: {batch_size}") + print("--------------------------------") + + finetune_sam2(tomo_train, fib_train, tomo_val, fib_val, sam2_cfg, epochs, batch_size) \ No newline at end of file diff --git a/saber/finetune/trainer.py b/saber/finetune/trainer.py new file mode 100644 index 0000000..a7c3412 --- /dev/null +++ b/saber/finetune/trainer.py @@ -0,0 +1,656 @@ +from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR +from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator +from saber.finetune.helper import save_training_log +from saber.finetune.abiou import automask_metrics +from saber.finetune.losses import MultiMaskIoULoss +from saber.finetune.metrics import sam2_metrics +import torch, os, optuna, random +import torch.nn.functional as F +from lightning import fabric +from tqdm import tqdm + +# Tests: +# Weights - +# Soften Mask Prompts Logits - ±3 + +class SAM2FinetuneTrainer: + def __init__(self, predictor, train_loader, val_loader, seed=42): + + # Store the predictor + self.predictor = predictor + + # Two parameter groups for different LRs (optional) + params = [ + # {"params": [p for p in self.predictor.model.image_encoder.parameters() + # if p.requires_grad], "lr": 6e-5}, + {"params": [p for p in self.predictor.model.sam_mask_decoder.parameters() + if p.requires_grad], "lr": 1e-4}, # 1e-4 + {"params": [p for p in self.predictor.model.sam_prompt_encoder.parameters() + if p.requires_grad], "lr": 5e-5}, # 5e-5 + ] + + # Initialize the optimizer and dataloaders + self.num_gpus = torch.cuda.device_count() + 
optimizer = torch.optim.AdamW(params, weight_decay=1e-3) + if self.num_gpus > 1: + self.fabric = fabric.Fabric(accelerator="cuda", strategy="ddp", devices=self.num_gpus) + self.fabric.launch() + self.predictor.model, self.optimizer = self.fabric.setup(self.predictor.model,optimizer) + self.predictor.model.mark_forward_method('forward_image') + self.autocast = self.fabric.autocast + self.use_fabric = True + else: + self.optimizer = optimizer + self.use_fabric = False + def _autocast(): + return torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()) + self.autocast = _autocast + self.device = next(self.predictor.model.parameters()).device + + # Setup dataloaders + if self.use_fabric: + self.train_loader, self.val_loader = self.fabric.setup_dataloaders(train_loader, val_loader) + else: + self.train_loader, self.val_loader = train_loader, val_loader + + # Initialize the loss function + self.focal_alpha = 0.5 + self.focal_gamma = 2.0 + self.supervise_all_iou = True + self.iou_use_l1_loss = True + self.predict_multimask = True + + # Automask Generator Parameters + self.amg_kwargs = dict( + points_per_side=32, + points_per_batch=128, + pred_iou_thresh=0.7, + stability_score_thresh=0.7, + stability_score_offset=0.5, + crop_n_layers=0, + crop_n_points_downscale_factor=2, + box_nms_thresh=0.6, + use_m2m=False, + multimask_output=False, + ) + self.nAMGtrials = 10 + + # Initialize the use_boxes flag + self.use_boxes = True + self._rng = random.Random(seed) + + # Initialize the save directory + self.save_dir = 'results' + os.makedirs(self.save_dir, exist_ok=True) + + @property + def is_global_zero(self): + # True on single-process runs; Fabric guards inside when present + return (not self.use_fabric) or (self.use_fabric is not None and self.fabric.is_global_zero) + + @torch.no_grad() + def _stack_image_embeddings_from_predictor(self): + """ + After predictor.set_image_batch(images), gather stacked image embeddings + and high-res features for all B images. 
+ Returns: + image_embeds: [B, C, H', W'] + hr_feats: list[level] of [B, C, H', W'] + """ + # image_embed is a list[len=B] of [C, H', W']; stack to [B, C, H', W'] + image_embeds = torch.stack(list(self.predictor._features["image_embed"]), dim=0).to(self.device) + + # high_res_feats is a list[level], where each level is a list[len=B] of [C, H', W'] + hr = self.predictor._features["high_res_feats"] + B = image_embeds.shape[0] + hr_feats = [torch.stack([lvl[b] for b in range(B)], dim=0).to(self.device) for lvl in hr] + return image_embeds, hr_feats + + def _determine_sampling(self, N, p_points=0.5, p_box=0.2, p_mask=0.2, p_mask_box=0.1): + """ + 0 = points only + 1 = box + points + 2 = mask + points + 3 = mask + box + points + """ + probs = [p_points, p_box, p_mask, p_mask_box] + s = sum(probs); probs = [p / s for p in probs] + # cumulative edges + e0, e1, e2 = probs[0], probs[0] + probs[1], probs[0] + probs[1] + probs[2] + + combo = [] + groups = {0: [], 1: [], 2: [], 3: []} + for i in range(N): + r = self._rng.random() + c = 0 if r < e0 else 1 if r < e1 else 2 if r < e2 else 3 + combo.append(c) + groups[c].append(i) + return combo, groups + + def forward_step(self, batch): + """ + Returns: + prd_masks: [N, K, H, W] (logits at image res) + prd_scores: [N, K] (predicted IoU/head) + gt_masks: [N, H, W] + inst_img_ix: [N] (which original image each instance came from) + """ + images = batch["images"] # list of B images (HxWx3); predictor handles types + B = len(images) + + # 1) encode once + self.predictor.set_image_batch(images) + image_embeds_B, hr_feats_B = self._stack_image_embeddings_from_predictor() + + # 2) flatten instances + inst_img_ix, gt_all, pts_all, lbl_all, box_all = [], [], [], [], [] + for b in range(B): + for m, p, l, bx in zip(batch["masks"][b], batch["points"][b], batch["labels"][b], batch["boxes"][b]): + inst_img_ix.append(b) + gt_all.append(m.to(self.device)) + pts_all.append(p.to(self.device)) + lbl_all.append(l.to(self.device)) + 
box_all.append(bx.to(self.device)) + + N = len(gt_all) + if N == 0: + return None, None, None, None + inst_img_ix = torch.tensor(inst_img_ix, device=self.device, dtype=torch.long) + + # (A) build raw per-instance tensors + boxes_full = torch.stack(box_all, dim=0).to(self.device, dtype=torch.float32) if len(box_all) > 0 else \ + torch.tensor([[0,0,1,1]], device=self.device, dtype=torch.float32).expand(N,4) + + # 5) mask logits (+/-6) + gt_masks_bin = torch.stack([m.to(torch.float32) for m in gt_all], dim=0).to(self.device) + mask_logits_full = (gt_masks_bin * 2.0 - 1.0) * 3.0 + + # pad clicks + max_p = max((p.shape[0] for p in pts_all), default=0) + pts_pad = torch.zeros((N, max_p, 2), device=self.device, dtype=torch.float32) + lbl_pad = torch.full((N, max_p), -1.0, device=self.device, dtype=torch.float32) + for i, (p, l) in enumerate(zip(pts_all, lbl_all)): + if p.numel(): + pts_pad[i, :p.shape[0]] = p.to(self.device, torch.float32) + lbl_pad[i, :l.shape[0]] = l.to(self.device, torch.float32) + + # (B) sampling + groups + combo, groups = self._determine_sampling(N) + + Hf, Wf = image_embeds_B.shape[-2], image_embeds_B.shape[-1] + target_mask_hw = (Hf * 4, Wf * 4) + + def _run_group(idxs, use_box, use_mask): + if not idxs: + return None + idxs_t = torch.tensor(idxs, device=self.device) + + # Slice features/images for this subgroup + img_emb = image_embeds_B[inst_img_ix[idxs_t]] # [n, C, Hf, Wf] + hr_sub = [lvl[inst_img_ix[idxs_t]] for lvl in hr_feats_B] + + # Slice raw prompts for this subgroup + pts_sub = pts_pad[idxs_t] # [n, P, 2] + lbl_sub = lbl_pad[idxs_t] # [n, P] + box_sub = boxes_full[idxs_t] if use_box else None # [n, 4] or None + mlog_sub = mask_logits_full[idxs_t] if use_mask else None # [n, Hm, Wm] or None + + # Get points + (normalized) boxes from the helper; do NOT pass mask here yet. + # This gives p_coords, p_labels, box_in in prompt space. 
+ mask_in_dummy, p_coords, p_labels, box_in = self.predictor._prep_prompts( + pts_sub, lbl_sub, box=box_sub, mask_logits=None, normalize_coords=True + ) + + # Build a base mask prompt from GT (if requested) at decoder’s expected spatial size + def _resize_to_target(m): + if m is None: + return None + m = m.to(self.device, dtype=torch.float32) + if m.dim() == 3: # [n, H, W] -> [n,1,H,W] + m = m.unsqueeze(1) + if m.shape[1] != 1: # keep single-channel + m = m[:, :1] + if m.shape[-2:] != target_mask_hw: + m = F.interpolate(m, target_mask_hw, mode="bilinear", align_corners=False) + return m + + base_mask_in = _resize_to_target(mlog_sub) # GT-derived mask prompt (logits in ±6) + + # Optional soft coarsening for prompts (make it a tad blurrier + smaller magnitude) + def _soften_coarsen(mlog): # mlog: [n,1,H,W] logits + if mlog is None: + return None + mlog = mlog.clamp(-6, 6) * 0.5 # ≈ ±3 + H, W = mlog.shape[-2], mlog.shape[-1] + # down → up for slight coarsening + scale = torch.empty(1, device=mlog.device).uniform_(0.5, 0.8).item() + h2, w2 = max(8, int(H * scale)), max(8, int(W * scale)) + with torch.no_grad(): + c = F.interpolate(mlog, (h2, w2), mode="bilinear", align_corners=False) + c = F.interpolate(c, (H, W), mode="bilinear", align_corners=False) + return 0.7 * c + 0.3 * mlog + + # ----------------------------- + # Two-pass refinement (only when use_mask=True): + # Pass-1: points/box only → predict mask (no grad), then use it as the mask prompt for Pass-2. + # If Pass-1 looks junky, fall back to a softened GT prompt. 
+ # ----------------------------- + mask_prompt_final = None + if use_mask: + use_pred_mask = torch.rand(1, device=self.device) < 0.5 # 50/50 pred vs GT prompt + + if use_pred_mask: + # PASS 1 (no mask prompt): encode prompts and decode once + sp1, dn1 = self.predictor.model.sam_prompt_encoder( + points=(p_coords, p_labels), boxes=box_in, masks=None + ) + with torch.no_grad(): # do not backprop through Pass-1 + low1, sc1, obj1, _ = self.predictor.model.sam_mask_decoder( + image_embeddings=img_emb, + image_pe=self.predictor.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sp1, + dense_prompt_embeddings=dn1, + multimask_output=True, # allow K proposals; we’ll pick best by score + repeat_image=False, + high_res_features=hr_sub, + ) + # pick best candidate by IoU score head + best_k = torch.argmax(sc1, dim=1) # [n] + n, k, h, w = low1.shape + gather_idx = best_k.view(n, 1, 1, 1).expand(n, 1, h, w) + pred_mask_in = low1.gather(1, gather_idx).detach() # [n,1,h,w], stop-grad + mask_prompt_final = _soften_coarsen(_resize_to_target(pred_mask_in)) + + # Fallback to GT prompt if predicted prompt is unusable (empty/near-zero) + if mask_prompt_final is None or (mask_prompt_final.abs().max() < 1e-4): + mask_prompt_final = _soften_coarsen(base_mask_in) + else: + # Use GT-derived prompt (softened) directly + mask_prompt_final = _soften_coarsen(base_mask_in) + # else: no mask prompt path, keep None + + # PASS 2: final decode (this is the ONLY output we train on) + sp_emb, dn_emb = self.predictor.model.sam_prompt_encoder( + points=(p_coords, p_labels), + boxes=box_in, + masks=mask_prompt_final, # None (no-mask path) or predicted/GT-mask prompt + ) + low, sc, obj, _ = self.predictor.model.sam_mask_decoder( + image_embeddings=img_emb, + image_pe=self.predictor.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sp_emb, + dense_prompt_embeddings=dn_emb, + multimask_output=self.predict_multimask, + repeat_image=False, + high_res_features=hr_sub, + ) + 
return (idxs, low, sc, obj) + + outs = [] + outs.append(_run_group(groups[0], use_box=False, use_mask=False)) # points + outs.append(_run_group(groups[1], use_box=True, use_mask=False)) # box+points + outs.append(_run_group(groups[2], use_box=False, use_mask=True)) # mask+points + outs.append(_run_group(groups[3], use_box=True, use_mask=True)) # mask+box+points + outs = [o for o in outs if o is not None] + + # (C) stitch back to original order + order = [] + low_list, score_list, obj_list = [], [], [] + for idxs, low, sc, obj in outs: + order.extend(idxs) + low_list.append(low); score_list.append(sc); obj_list.append(obj) + perm = torch.tensor(order, device=self.device).argsort() + + low_res_masks = torch.cat(low_list, dim=0)[perm] + prd_scores = torch.cat(score_list, dim=0)[perm] + obj_logits = torch.cat(obj_list, dim=0)[perm] + + # (D) upsample and finish + target_sizes = [self.predictor._orig_hw[int(b)] for b in inst_img_ix] + upsampled = [] + for i in range(low_res_masks.shape[0]): + H, W = target_sizes[i] + up_i = self.predictor._transforms.postprocess_masks(low_res_masks[i:i+1], (H, W)) + upsampled.append(up_i) + prd_masks = torch.cat(upsampled, dim=0) + + gt_masks = torch.stack(gt_all, dim=0).float() + return prd_masks, prd_scores, gt_masks, obj_logits, inst_img_ix + + @torch.no_grad() + def validate_step(self, max_images=float('inf'), best_metric='ABIoU'): + """ + Validate the model on the given batch. 
+ """ + + # Set the model to evaluation mode + self.predictor.model.eval() + + # Initialize the AMG + amg = SAM2AutomaticMaskGenerator( + model=self.predictor.model, + **self.amg_kwargs + ) + + # --- local accumulators (tensor) --- + loss_keys = ["loss_total", "loss_iou", "loss_dice", "loss_mask", "loss_class"] + losses_sum = {k: torch.tensor(0.0, device=self.device) for k in loss_keys} + n_inst = torch.tensor(0.0, device=self.device) + n_imgs = torch.tensor(0.0, device=self.device) + + # Initialize the metrics sum + metrics_sum = {k: torch.tensor(0.0, device=self.device) for k in self.metric_keys} + + num_images_seen = 0 + for batch in self.val_loader: + + # Compute Loss on decoder outputs + out = self.forward_step(batch) + if out[0] is None: + continue # no instances in this batch + prd_masks, prd_scores, gt_masks, obj_logits = out[:4] + local_n = torch.tensor(float(gt_masks.shape[0]), device=self.device) + + with self.autocast(): + batch_losses = self.loss_fn(prd_masks, prd_scores, gt_masks, obj_logits) + + if best_metric == 'ABIoU': + m = automask_metrics( + self.predictor, # predictor or predictor.model (your function supports either) + batch["images"], # list[H×W×3] or list[H×W] + batch["masks"], # list[list[H×W]] + top_k=150, + device=self.device, + autocast_ctx=self.autocast, + amg_kwargs=self.amg_kwargs, + ) + else: + m = sam2_metrics(batch, out, amg) + + # means → sums + for k in loss_keys: + # detach→cpu→float to avoid graph + dtype issues + losses_sum[k] += float(batch_losses[k].detach().cpu()) * local_n + n_inst += local_n + + # Weight by number of images so we can average correctly later + img_count = float(m["num_images"]) + for k in self.metric_keys: + metrics_sum[k] += torch.tensor(m[k] * img_count, device=self.device) + n_imgs += torch.tensor(img_count, device=self.device) + + num_images_seen += img_count + if num_images_seen >= max_images: + break + + # Reduce losses across ranks + losses_sum = self._all_reduce_sum(losses_sum) + n_inst = 
self._all_reduce_sum(n_inst) + n_imgs = self._all_reduce_sum(n_imgs) + metrics_sum = self._all_reduce_sum(metrics_sum) + + # Avoid divide-by-zero + img_denom = max(n_imgs.item(), 1.0) + inst_denom = max(n_inst.item(), 1.0) + + out = { + "loss_total": (losses_sum["loss_total"] / inst_denom).item(), + "loss_iou": (losses_sum["loss_iou"] / inst_denom).item(), + "loss_dice": (losses_sum["loss_dice"] / inst_denom).item(), + "loss_mask": (losses_sum["loss_mask"] / inst_denom).item(), + "loss_class": (losses_sum["loss_class"] / inst_denom).item(), + "num_images": int(img_denom), + } + out.update({k: (metrics_sum[k] / img_denom).item() for k in self.metric_keys}) + + if 'cal_tables' in m and self.is_global_zero: + # (optional) keep the last batch's cal_tables on rank0 for inspection: + out["cal_tables"] = m.get("cal_tables", None) + + return out + + def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 1e4): + """ + Fine Tune SAM2 on the given data. + """ + + # Initialize the loss function + weight_dict = {"loss_mask": 1.0, "loss_dice": 20.0, "loss_iou": 1.0, "loss_class": 1.0} + self.loss_fn = MultiMaskIoULoss( + supervise_all_iou=True, # start focused; re-enable later if you want + all_iou_weight=0.1, + weight_dict=weight_dict, + pred_obj_scores=True, + focal_alpha=0.25, focal_gamma=2.0, + focal_alpha_obj=0.5, focal_gamma_obj=0.0, # objectness as plain-ish BCE + gate_by_objectness=False, + ) + # in trainer.__init__ after building self.loss_fn + self.iou_head_weight_base = 1.0 # for later + self.loss_fn.iou_head_weight = 0.25 # start gentle + + # Initialize the metric keys + if best_metric == 'ABIoU': + self.metric_keys = ['ABIoU'] + else: + self.metric_keys = [ + 'prompt_miou', 'cal_mae', 'cal_brier', 'cal_ece', + 'AR', 'R@10', 'R@50', 'R@100'] + + # Cosine scheduler w/Warmup ---- + self.warmup_epochs = 10 + self.warmup_sched = LinearLR(self.optimizer, start_factor=0.1, total_iters=self.warmup_epochs) + self.cosine_sched = 
CosineAnnealingLR(self.optimizer, T_max=(num_epochs - self.warmup_epochs), eta_min=1e-6) + self.scheduler = SequentialLR(self.optimizer, [self.warmup_sched, self.cosine_sched], milestones=[self.warmup_epochs]) + + # Progress bar only on rank 0 + if self.is_global_zero: + pbar = tqdm(total=num_epochs, desc='Fine Tuning SAM2', unit='epoch', + leave=True, dynamic_ncols=True) + else: + pbar = None + + self.optimizer.zero_grad() + best_metric_value = float('-inf') + # Main Loop + for epoch in range(num_epochs): + + # Scale the all_iou_weights based on epoch + if epoch < self.warmup_epochs: + self.loss_fn.supervise_all_iou = False + self.loss_fn.all_iou_weight = 0.0 + + else: + self.loss_fn.supervise_all_iou = True + self.loss_fn.all_iou_weight = min(0.05, 0.01*(epoch - self.warmup_epochs + 1)) + # gentle ramp for the IoU head + t = epoch - self.warmup_epochs + 1 + self.loss_fn.iou_head_weight = min(1.0, 0.25 + 0.15 * t) # e.g., 0.25→1.0 over ~6 epochs + + self.predictor.model.train() + for batch in self.train_loader: + out = self.forward_step(batch) + if out[0] is None: + continue + prd_masks, prd_scores, gt_masks, obj_logits = out[:4] + with self.autocast(): + batch_losses = self.loss_fn(prd_masks, prd_scores, gt_masks, obj_logits) + if self.use_fabric: + self.fabric.backward(batch_losses['loss_total']) + else: + batch_losses['loss_total'].backward() + + # number of instances this rank used to compute its per-batch means + _local_n = gt_masks.shape[0] + + # (optional) gradient clip: + if self.use_fabric: + # norm-based clipping (L2) on all params in the optimizer + self.fabric.clip_gradients( + self.predictor.model, + self.optimizer, + max_norm=1.0, + norm_type=2.0 + ) + else: + torch.nn.utils.clip_grad_norm_( + self.predictor.model.parameters(), + 1.0, norm_type=2.0) + + self.optimizer.step() + self.optimizer.zero_grad() + losses = self._reduce_losses(batch_losses, _local_n) + + # Learning Rate Scheduler + self.scheduler.step() + + # Validate + metrics = {} + if 
(epoch+1) % 1e4 == 0: + metrics['val'] = self.amg_param_tuner() + else: + metrics['val'] = self.validate_step(best_metric=best_metric) + metrics['train'] = losses + + # Print Only on Rank 0 + if self.is_global_zero: + pbar.set_postfix({ + "train_loss": f"{metrics['train']['loss_total']:.4f}", + "val_loss": f"{metrics['val']['loss_total']:.4f}", + f"val_{best_metric}": f"{metrics['val'][best_metric]:.4f}", + }) + pbar.update(1) + + # Save Training Log + metrics['epoch'] = epoch + metrics['lr_mask'] = self.scheduler.get_last_lr()[0] + metrics['lr_prompt'] = self.scheduler.get_last_lr()[1] + save_training_log(metrics, self.save_dir, self.metric_keys) + + # Save Model if best metric is achieved + ckpt = {"model": self.predictor.model.state_dict()} + metric_value = metrics['val'].get(best_metric) + if metric_value > best_metric_value: + best_metric_value = metric_value + torch.save(ckpt, f"{self.save_dir}/best_model.pth") + print(f"Best {best_metric} saved!") + else: + torch.save(ckpt, f"{self.save_dir}/bad_model.pth") + + def _reduce_losses(self, losses, num_elems: int = None): + """ + Reduce the losses across ranks. 
+ """ + key_map = { + "loss_iou": "loss_iou", + "loss_dice": "loss_dice", + "loss_mask": "loss_mask", + "loss_class": "loss_class", + "loss_total": "loss_total", + } + count = torch.tensor(float(num_elems if num_elems is not None else 1.0), device=self.device) + out = {} + if self.use_fabric: + global_count = self.fabric.all_reduce(count, reduce_op="sum") + for long_k, short_k in key_map.items(): + if long_k not in losses: + continue + num = torch.tensor(float(losses[long_k].detach().item()), device=self.device) * count + global_num = self.fabric.all_reduce(num, reduce_op="sum") + out[short_k] = (global_num / torch.clamp(global_count, min=1.0)).item() + else: + for long_k, short_k in key_map.items(): + if long_k in losses: + out[short_k] = float(losses[long_k].detach().item()) + return out + + def _all_reduce_sum(self, x): + if not self.use_fabric: + return x + if isinstance(x, torch.Tensor): + return self.fabric.all_reduce(x, reduce_op="sum") + if isinstance(x, dict): + return {k: self.fabric.all_reduce(v, reduce_op="sum") for k, v in x.items()} + raise TypeError(f"_all_reduce_sum expects Tensor or dict[str,Tensor], got {type(x)}") + + # def _all_reduce_sum(self, x: torch.Tensor) -> torch.Tensor: + # """ + # """ + # return self.fabric.all_reduce(x, reduce_op="sum") if self.use_fabric else x + +############### Experimental - Automatic Mask Generator Tuning ############### + + def amg_param_tuner(self, n_trials=10): + """ + Tune a few AMG thresholds with Bayesian optimization (TPE). + Warm-start from the current self.amg_kwargs. + """ + + if self.use_fabric and not self.is_global_zero: + # Non-zero ranks: wait for rank 0 to finish tuning and broadcast params. 
+ self.fabric.barrier() + # Receive updated dict from rank 0 + self.amg_kwargs = self.fabric.broadcast(self.amg_kwargs, src=0) + # Now run normal distributed validation so logs are comparable + return self.validate_step() + + + # Use a fixed sampler (seed for reproducibility) + study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=0)) + + # ---- Warm start with current params ---- + # IMPORTANT: names must match suggest_* names in objective + warm = { + "pred_iou_thresh": float(self.amg_kwargs["pred_iou_thresh"]), + "stability_score_thresh": float(self.amg_kwargs["stability_score_thresh"]), + "stability_score_offset": float(self.amg_kwargs["stability_score_offset"]), + } + study.enqueue_trial(warm) + + def objective(trial: optuna.Trial) -> float: + """ + Objective for Optuna: maximize ABIoU on a held-out validation set, + varying only a few AMG thresholds. Do NOT mutate self.amg_kwargs here. + """ + # Suggest in sensible finetuned ranges + pred_iou_thresh = trial.suggest_float("pred_iou_thresh", 0.40, 0.75) + stability_score_thresh = trial.suggest_float("stability_score_thresh", 0.55, 0.90) + stability_score_offset = trial.suggest_float("stability_score_offset", 0.00, 0.30) + + # Build a LOCAL kwargs dict (copy), keep other knobs fixed + amg_kwargs_trial = dict(self.amg_kwargs) + amg_kwargs_trial.update({ + "pred_iou_thresh": pred_iou_thresh, + "stability_score_thresh": stability_score_thresh, + "stability_score_offset": stability_score_offset, + }) + + # Validate the model with the new AMG thresholds + metrics = self.validate_step( + amg_kwargs_trial, max_images=100, + reduce_all_ranks=not self.use_fabric + ) + return metrics['ABIoU'] + + # Optimize + study.optimize(objective, n_trials=n_trials, show_progress_bar=False) + + # Update trainer state with the best params + best = study.best_params + self.amg_kwargs.update({ + "pred_iou_thresh": float(best["pred_iou_thresh"]), + "stability_score_thresh": 
float(best["stability_score_thresh"]), + "stability_score_offset": float(best["stability_score_offset"]), + }) + + # Let other ranks proceed and receive the dict + if self.use_fabric: + self.fabric.barrier() + self.amg_kwargs = self.fabric.broadcast(self.amg_kwargs, src=0) + + + if self.is_global_zero: + print("AMG tuned →", {k: self.amg_kwargs[k] for k in ["pred_iou_thresh","stability_score_thresh","stability_score_offset"]}) + + # Now run the normal distributed validate_step() so metrics are globally averaged + return self.validate_step() \ No newline at end of file diff --git a/saber/main.py b/saber/main.py index 19169e4..768d888 100644 --- a/saber/main.py +++ b/saber/main.py @@ -1,17 +1,14 @@ -from saber.entry_points.segment_methods import methods as segment from saber.classifier.cli import classifier_routines as classifier from saber.entry_points.run_low_pass_filter import cli as filter3d +from saber.entry_points.segment_methods import methods as segment +from saber.finetune.cli import finetune_routines as finetune from saber.analysis.analysis_cli import methods as analysis from saber.entry_points.run_analysis import cli as save -from saber.pretrained_weights import cli as download -from saber.utils.importers import cli as importers import click try: from saber.gui.base.zarr_gui import gui - from saber.gui.text.zarr_text_gui import text_gui gui_available = True except Exception as e: - print(f"GUI is not available: {e}") gui_available = False @click.group() @@ -23,9 +20,9 @@ def routines(): routines.add_command(analysis) routines.add_command(filter3d) routines.add_command(classifier) +routines.add_command(finetune) if gui_available: routines.add_command(gui) - routines.add_command(text_gui) routines.add_command(segment) routines.add_command(save) diff --git a/saber/utils/preprocessing.py b/saber/utils/preprocessing.py index f024cd9..1497cf6 100644 --- a/saber/utils/preprocessing.py +++ b/saber/utils/preprocessing.py @@ -3,7 +3,7 @@ def contrast(image, 
std_cutoff=5): """ - Normalize the Input Data to [0,1] + Clip the Image by ±5std """ image_mean = uniform_filter(image, size=500) image_sq = uniform_filter(image**2, size=500) @@ -14,7 +14,9 @@ def contrast(image, std_cutoff=5): return np.clip(image, -std_cutoff, std_cutoff) def normalize(image, rgb = False): - # Clip the Volume by ±5std + """ + Normalize the Input Data to [0,1] + """ if rgb: min_vals = image.min(axis=(0, 1), keepdims=True) max_vals = image.max(axis=(0, 1), keepdims=True) @@ -24,6 +26,23 @@ def normalize(image, rgb = False): normalized = (image - min_vals) / (max_vals - min_vals + 1e-8) # Add epsilon to avoid div by zero return normalized +def proprocess(image: np.ndarray, std_cutoff=3, rgb=False): + """ + Preprocesses an image for segmentation. + + Parameters: + image (np.ndarray): 3D image array (z, y, x). + std_cutoff (int, optional): Standard deviation cutoff for contrast normalization. + rgb (bool, optional): Whether the input image is RGB. + + Returns: + np.ndarray: Preprocessed image. + """ + image = contrast(image, std_cutoff=std_cutoff) + image = normalize(image, rgb=rgb) + image = image if rgb else np.repeat(image[..., None], 3, axis=2) + return image + def project_tomogram(vol, zSlice = None, deltaZ = None): """ Projects a tomogram along the z-axis. @@ -51,3 +70,24 @@ def project_tomogram(vol, zSlice = None, deltaZ = None): projection = np.mean(vol, axis=0) return projection + +def project_segmentation(vol, zSlice = None, deltaZ = None): + """ + Projects a segmentation along the z-axis. + + Parameters: + vol (np.ndarray): 3D segmentation array (z, y, x). + zSlice (int, optional): Specific z-slice to project. If None, project along all z slices. + deltaZ (int, optional): Thickness of slices to project. Used only if zSlice is specified. If None, project a single slice. 
+ """ + + if zSlice is not None: + if deltaZ is not None: + zStart = int(max(zSlice - deltaZ, 0)) + zEnd = int(min(zSlice + deltaZ, vol.shape[0])) + projection = np.max(vol[zStart:zEnd,], axis=0) + else: + projection = np.max(vol[zSlice,], axis=0) + else: + projection = np.max(vol, axis=0) + return projection \ No newline at end of file diff --git a/saber/visualization/classifier.py b/saber/visualization/classifier.py index 8445a4e..149198a 100755 --- a/saber/visualization/classifier.py +++ b/saber/visualization/classifier.py @@ -1,5 +1,5 @@ +from matplotlib.colors import ListedColormap, hsv_to_rgb from matplotlib.widgets import TextBox, Button -from matplotlib.colors import ListedColormap import matplotlib.pyplot as plt import numpy as np @@ -29,7 +29,8 @@ def display_mask_list(image: np.ndarray, masks: list, save_button: bool = False) display_mask_array(image, masks, save_button) def display_mask_array(image: np.ndarray, masks: np.ndarray, save_button: bool = False): - colors = get_colors() + n_labels = int(np.max(masks)) + colors = get_colors(n_needed=n_labels) # Create figure with extra space for widgets fig = plt.figure(figsize=(9, 7)) @@ -38,7 +39,7 @@ def display_mask_array(image: np.ndarray, masks: np.ndarray, save_button: bool = ax_img = plt.axes([0.1, 0.2, 0.8, 0.75]) ax_img.imshow(image, cmap='gray') - cmap_colors = [(1, 1, 1, 0)] + colors[:np.max(masks)] + cmap_colors = [(1, 1, 1, 0)] + colors cmap = ListedColormap(cmap_colors) ax_img.imshow(masks, cmap=cmap, alpha=0.6) ax_img.axis('off') @@ -303,45 +304,58 @@ def plot_per_class_metrics(per_class_results, save_path=None): else: plt.show() -def get_colors(): +def get_colors(n_needed=None, alpha=0.5): # Extended vibrant color palette - colors = [ - (0, 1, 1, 0.5), # Cyan (bright, high contrast) - (1, 0, 1, 0.5), # Magenta - (0, 0, 1, 0.5), # Blue - (0, 1, 0, 0.5), # Green - - (1, 0.5, 0, 0.5), # Orange - (0.5, 0, 0.5, 0.5), # Purple - (0.2, 0.6, 0.9, 0.5), # Sky Blue - (0.9, 0.2, 0.6, 0.5), # Hot Pink - 
(0.6, 0.2, 0.8, 0.5), # Violet - - (0.4, 0.7, 0.2, 0.5), # Lime - (0.8, 0.4, 0, 0.5), # Burnt Orange - (0, 0.5, 0, 0.5), # Dark Green - (0.7, 0.3, 0.6, 0.5), # Orchid - (0.9, 0.6, 0.2, 0.5), # Gold - - (1, 1, 0.3, 0.5), # Yellow - (0.5, 0.5, 0, 0.5), # Olive - (0, 0, 0.5, 0.5), # Navy - (0.5, 0, 0, 0.5), # Maroon + base = [ + (0, 1, 1, alpha), # Cyan (bright, high contrast) + (1, 0, 1, alpha), # Magenta + (0, 0, 1, alpha), # Blue + (0, 1, 0, alpha), # Green + + (1, 0.5, 0, alpha), # Orange + (0.5, 0, 0.5, alpha), # Purple + (0.2, 0.6, 0.9, alpha), # Sky Blue + (0.9, 0.2, 0.6, alpha), # Hot Pink + (0.6, 0.2, 0.8, alpha), # Violet + + (0.4, 0.7, 0.2, alpha), # Lime + (0.8, 0.4, 0, alpha), # Burnt Orange + (0, 0.5, 0, alpha), # Dark Green + (0.7, 0.3, 0.6, alpha), # Orchid + (0.9, 0.6, 0.2, alpha), # Gold + + (1, 1, 0.3, alpha), # Yellow + (0.5, 0.5, 0, alpha), # Olive + (0, 0, 0.5, alpha), # Navy + (0.5, 0, 0, alpha), # Maroon # Pastel shades (can be used for less prominent classes) - (1, 0.7, 0.7, 0.5), # Light Red/Pink - (0.7, 1, 0.7, 0.5), # Light Green - (0.7, 0.7, 1, 0.5), # Light Blue - (1, 1, 0.7, 0.5), # Light Yellow + (1, 0.7, 0.7, alpha), # Light Red/Pink + (0.7, 1, 0.7, alpha), # Light Green + (0.7, 0.7, 1, alpha), # Light Blue + (1, 1, 0.7, alpha), # Light Yellow ] - return colors + if n_needed is None: + return base + n_needed = int(n_needed) + if n_needed <= len(base): + return base[:n_needed] + + extra_needed = n_needed - len(base) + hues = np.linspace(0, 1, extra_needed, endpoint=False) + extra = [tuple(hsv_to_rgb((h, 0.9, 0.9))) + (alpha,) for h in hues] + + return base + extra def add_masks(masks, ax): + # Get number of masks + num_masks = masks.shape[0] + # Get colors - colors = get_colors() + colors = get_colors(n_needed=num_masks) # Get number of masks num_masks = masks.shape[0]