From 75731689ba4735b6f5a131801d2971ec233d785f Mon Sep 17 00:00:00 2001 From: root Date: Mon, 15 Sep 2025 18:40:50 +0000 Subject: [PATCH 01/15] Begin adding the worklow for fine tuning sam2 on experimental data --- saber/finetune/__init__.py | 0 saber/finetune/dataset.py | 231 ++++++++++++++++++++++++++ saber/finetune/helper.py | 120 ++++++++++++++ saber/finetune/losses.py | 307 +++++++++++++++++++++++++++++++++++ saber/finetune/train.py | 68 ++++++++ saber/finetune/trainer.py | 33 ++++ saber/main.py | 8 +- saber/utils/preprocessing.py | 44 ++++- 8 files changed, 804 insertions(+), 7 deletions(-) create mode 100644 saber/finetune/__init__.py create mode 100644 saber/finetune/dataset.py create mode 100644 saber/finetune/helper.py create mode 100644 saber/finetune/losses.py create mode 100644 saber/finetune/train.py create mode 100644 saber/finetune/trainer.py diff --git a/saber/finetune/__init__.py b/saber/finetune/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/saber/finetune/dataset.py b/saber/finetune/dataset.py new file mode 100644 index 0000000..9fdc4b6 --- /dev/null +++ b/saber/finetune/dataset.py @@ -0,0 +1,231 @@ +import saber.finetune.helper as helper +from saber.utils import preprocessing +from torch.utils.data import Dataset +import numpy as np +import zarr, torch + +class AutoMaskDataset(Dataset): + def __init__(self, + tomogram_zarr_path: str = None, + fib_zarr_path: str = None, + transform = None, + slabs_per_volume_per_epoch: int = 10, + slices_per_fib_per_epoch: int = 5, + slab_thickness: int = 5): + """ + Args: + tomogram_zarr_path: Path to the tomogram zarr store + fib_zarr_path: Path to the fib zarr store + transform: Transform to apply to the data + slabs_per_volume_per_epoch: Number of slabs per volume per epoch + slices_per_fib_per_epoch: Number of slices per fib per epoch + slab_thickness: Thickness of the slab + """ + + # Slabs per Epoch + self.slab_thickness = slab_thickness + self.slabs_per_volume_per_epoch = 
slabs_per_volume_per_epoch + self.slices_per_fib_per_epoch = slices_per_fib_per_epoch + + # Grid and Positive Points for AutoMaskGenerator + self.points_per_side = 32 + self.points_per_batch = 64 + self.min_area = 0.001 + self.k_pos = 2 + + # Check if both data types are available + if tomogram_zarr_path is None and fib_zarr_path is None: + raise ValueError("At least one of tomogram_zarr_path or fib_zarr_path must be provided") + + # Flags to track which data types are available + self.has_tomogram = tomogram_zarr_path is not None + self.has_fib = fib_zarr_path is not None + + # Initialize tomogram data if provided + if self.has_tomogram: + self.tomogram_store = zarr.open(tomogram_zarr_path, mode='r') + self.tomogram_keys = [k for k in self.tomogram_store.keys() if not k.startswith('.')] + self.n_tomogram_volumes = len(self.tomogram_keys) + self.tomo_shapes = {} + for i, key in enumerate(self.tomogram_keys): + self.tomo_shapes[i] = self.tomogram_store[key]['0'].shape + else: + self.n_tomogram_volumes = 0 + self.tomo_shapes = {} + self.tomogram_keys = [] + + # Initialize fib data if provided + if self.has_fib: + self.fib_store = zarr.open(fib_zarr_path, mode='r') + self.fib_keys = [k for k in self.fib_store.keys() if not k.startswith('.')] + self.n_fib_volumes = len(self.fib_keys) + self.fib_shapes = {} + for i, key in enumerate(self.fib_keys): + self.fib_shapes[i] = self.fib_store[key]['0'].shape + else: + self.n_fib_volumes = 0 + self.fib_shapes = {} + self.fib_keys = [] + + # Resample epoch + self.resample_epoch() + + # Verbose Flag + self.verbose = False + + def resample_epoch(self): + """ Generate new random samples for this epoch """ + self.tomogram_samples = [] + self.fib_samples = [] + + # Sample random slabs from each tomogram + if self.has_tomogram: + for vol_idx in range(self.n_tomogram_volumes): + volume_shape = self.tomo_shapes[vol_idx] + # Valid range for center of slab + valid_z_min = int(volume_shape[0] / 4) + valid_z_max = int(volume_shape[0] * (3 / 
4)) + + if valid_z_max > valid_z_min: + z_positions = np.random.randint( + valid_z_min, + valid_z_max, + size=self.slabs_per_volume_per_epoch + ) + + for z_pos in z_positions: + self.tomogram_samples.append((vol_idx, z_pos)) + np.random.shuffle(self.tomogram_samples) # Shuffle samples + + # Sample random slices from each FIB volume + if self.has_fib: + for fib_idx in range(self.n_fib_volumes): + fib_shape = self.fib_shapes[fib_idx] + # Sample random z positions from this FIB volume + z_positions = np.random.randint( + 0, + fib_shape[0], + size=self.slices_per_fib_per_epoch + ) + + for z_pos in z_positions: + self.fib_samples.append((fib_idx, z_pos)) + np.random.shuffle(self.fib_samples) # Shuffle samples + + # Set epoch length + self.epoch_length = len(self.tomogram_samples) + len(self.fib_samples) + + def __len__(self): + return self.epoch_length + + def __getitem__(self, idx): + + # Get item from tomogram or FIB + if idx < len(self.tomogram_samples) and self.has_tomogram: + return self._get_tomogram_item(idx) + else: + return self._get_fib_item(idx - len(self.tomogram_samples)) + + def _get_tomogram_item(self, idx): + + # Randomly select a tomogram volume + vol_idx, z_pos = self.tomogram_samples[idx] + key = self.tomogram_keys[vol_idx] + + # Load image and segmentation slab + z_start = z_pos - self.slab_thickness // 2 + z_end = z_pos + self.slab_thickness // 2 + 1 + image_slab = self.tomogram_store[key]['0'][z_start:z_end] + seg_slab = self.tomogram_store[key]['labels/0'][z_start:z_end] + + # Project slab and normalize + image_projection = preprocessing.project_tomogram(image_slab) + image_2d = preprocessing.proprocess(image_projection) # 3xHxW + + # Project segmentation + seg_2d = preprocessing.project_segmentation(seg_slab) # HxW + + return self._package_image_item(image_2d, seg_2d) + + def _get_fib_item(self, idx): + + # Randomly select a FIB volume + fib_idx, z_pos = self.fib_samples[idx] + key = self.fib_keys[fib_idx] + + # Load FIB image and segmentation + 
image = self.fib_store[key]['0'][z_pos,] + image_2d = preprocessing.proprocess(image) + seg_2d = self.fib_store[key]['labels/0'][z_pos,] + + return self._package_image_item(image_2d, seg_2d) + + def _gen_grid_points(self, h: int, w: int) -> np.ndarray: + """ + Generate grid points for a given image size + """ + xs = np.linspace(0.5, w - 0.5, self.points_per_side, dtype=np.float32) + ys = np.linspace(0.5, h - 0.5, self.points_per_side, dtype=np.float32) + xx, yy = np.meshgrid(xs, ys) + return np.stack([xx.ravel(), yy.ravel()], axis=1) # (G,2) as (x,y) + + def _package_image_item(self, + image_2d: np.ndarray, + segmentation: np.ndarray): + """ + Build per-component targets using grid-hit instances. + - Splits each hit instance into connected components. + - Drops components smaller than min_area_frac * (H*W). + - Emits only positive clicks + boxes (no negatives). + Returns: + { + "image": HxWx3 uint8, + "masks": list[H x W] float32 in {0,1}, + "points": list[#p x 2] float32 (xy), + "labels": list[#p] float32 (all ones), + "boxes": list[4] float32 (x0,y0,x1,y1) + } + """ + + h, w = segmentation.shape + min_pixels = 0 + # min_pixels = int(self.min_area * h * w) + + # which instances to train on for this image + grid_points = self._gen_grid_points(h, w) + inst_ids = helper.instances_from_grid(grid_points, segmentation) + masks_t, points_t, labels_t, boxes_t = [], [], [], [] + + for iid in inst_ids: + comps = helper.components_for_id(segmentation, iid, min_pixels) + for comp in comps: + # box from this component + box = helper.mask_to_box(comp) + if box is None: + continue + + # sample clicks from this component (NOT the full instance) + pts = helper.sample_positive_points(comp, k=self.k_pos) + if pts.shape[0] == 0: + continue + + masks_t.append(torch.from_numpy(comp.astype(np.float32))) + points_t.append(torch.from_numpy(pts.astype(np.float32))) + labels_t.append(torch.from_numpy(np.ones((pts.shape[0],), dtype=np.float32))) + 
boxes_t.append(torch.from_numpy(box.astype(np.float32))) + + # fallback to a harmless dummy if nothing was hit by the grid (keeps loader stable) + if len(masks_t) == 0: + print("No masks found") + masks_t = [torch.from_numpy(np.zeros_like(segmentation, dtype=np.float32))] + points_t = [torch.from_numpy(np.zeros((1, 2), dtype=np.float32))] + labels_t = [torch.from_numpy(np.ones((1,), dtype=np.float32))] + boxes_t = [torch.from_numpy(np.array([0, 0, 1, 1], dtype=np.float32))] + + return { + "image": image_2d, # HxWx3 uint8 + "masks": masks_t, # list[H x W] float32 in {0,1} + "points": points_t, # list[#p x 2] float32 (xy) + "labels": labels_t, # list[#p] all ones + "boxes": boxes_t, # list[4] (x0,y0,x1,y1) + } \ No newline at end of file diff --git a/saber/finetune/helper.py b/saber/finetune/helper.py new file mode 100644 index 0000000..e3e8e5e --- /dev/null +++ b/saber/finetune/helper.py @@ -0,0 +1,120 @@ +from saber.visualization.classifier import get_colors, add_masks +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import numpy as np +import cv2 + +def mask_to_box(mask: np.ndarray) -> np.ndarray | None: + """xyxy box from a binary mask (H,W) in {0,1}.""" + ys, xs = np.where(mask > 0) + if xs.size == 0: + return None + return np.array([xs.min(), ys.min(), xs.max(), ys.max()], dtype=np.float32) + +def sample_positive_points(mask: np.ndarray, k: int = 1) -> np.ndarray: + """Sample k (x,y) positives uniformly from mask pixels.""" + ys, xs = np.where(mask > 0) + if xs.size == 0: + return np.zeros((0, 2), dtype=np.float32) + idx = np.random.randint(0, xs.size, size=k) + return np.stack([xs[idx], ys[idx]], axis=1).astype(np.float32) + +def instances_from_grid(grid_points: np.ndarray, segmentation: np.ndarray) -> list[int]: + """Return unique non-zero instance IDs at the given (x,y) grid sample points.""" + h, w = segmentation.shape + xs = np.clip(grid_points[:, 0].astype(np.int32), 0, w - 1) + ys = np.clip(grid_points[:, 1].astype(np.int32), 0, h 
- 1) + ids = segmentation[ys, xs] + ids = ids[ids > 0] + return np.unique(ids).astype(int).tolist() + +# helper: split an instance id into connected components +def components_for_id(seg, iid: int, min_pixels: int): + mask_u8 = (seg == iid).astype(np.uint8) # 0/1 + num, lbl = cv2.connectedComponents(mask_u8) + comps = [] + for cid in range(1, num): # skip background=0 + comp = (lbl == cid).astype(np.float32) # HxW {0,1} + if comp.sum() >= min_pixels: + comps.append(comp) + return comps + +def collate_autoseg(batch): + # batch: list of dicts from _package_image_item + return { + "images": [b["image"] for b in batch], # list of HxWx3 uint8 + "masks": [b["masks"] for b in batch], # list of list[H x W] + "points": [b["points"] for b in batch], # list of list[#p x 2] + "labels": [b["labels"] for b in batch], # list of list[#p] + "boxes": [b["boxes"] for b in batch], # list of list[4] + } + +def _to_numpy_mask_stack(masks): + """ + Accepts list/tuple of tensors or np arrays shaped [H,W]; + returns np.uint8 array [N,H,W] with values {0,1}. + """ + if isinstance(masks, np.ndarray) and masks.ndim == 3: + arr = masks + else: + arr = np.stack([m.detach().cpu().numpy() if hasattr(m, "detach") else np.asarray(m) + for m in masks], axis=0) + # binarize & cast + arr = (arr > 0).astype(np.uint8) + return arr + +def visualize_item_with_points(image, masks, points, boxes=None, + title=None, point_size=24): + """ + Show a single image with all component masks (colored), + all positive points (color-matched to masks), and optional boxes. 
+ + Args + ---- + image : np.ndarray (H,W) or (H,W,3) (uint8 or float) + masks : list[np.ndarray or torch.Tensor] each [H,W] in {0,1} + points: list[np.ndarray] each [P,2] as (x,y) + boxes : list[np.ndarray] each [4] xyxy (optional) + """ + # normalize inputs + mstack = _to_numpy_mask_stack(masks) # [N,H,W] in {0,1} + + # set up figure + fig, ax = plt.subplots(1, 1, figsize=(5, 5)) + if image.ndim == 2: + ax.imshow(image, cmap="gray", interpolation="nearest") + else: + ax.imshow(image, interpolation="nearest") + + # overlay masks with your helper + add_masks(mstack, ax) # uses get_colors internally + + # color-match points (and boxes) to mask color + colors = get_colors() + for i, pts in enumerate(points): + if pts is None or len(pts) == 0: + continue + color = colors[i % len(colors)] + ax.scatter(pts[:, 0], pts[:, 1], s=point_size, + c=[(color[0], color[1], color[2], 1.0)], + edgecolors="k", linewidths=0.5) + + if boxes is not None: + for i, b in enumerate(boxes): + if b is None: + continue + color = colors[i % len(colors)] + x0, y0, x1, y1 = map(float, b) + rect = patches.Rectangle( + (x0, y0), x1 - x0, y1 - y0, + linewidth=2, + edgecolor=(color[0], color[1], color[2], 1.0), + facecolor="none" + ) + ax.add_patch(rect) + + if title: + ax.set_title(title) + ax.axis("off") + plt.tight_layout() + plt.show() \ No newline at end of file diff --git a/saber/finetune/losses.py b/saber/finetune/losses.py new file mode 100644 index 0000000..d281b1a --- /dev/null +++ b/saber/finetune/losses.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from collections import defaultdict +from typing import Dict, List + +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F + +from training.trainer import CORE_LOSS_KEY + +from training.utils.distributed import get_world_size, is_dist_avail_and_initialized + + +def dice_loss(inputs, targets, num_objects, loss_on_multimask=False): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + num_objects: Number of objects in the batch + loss_on_multimask: True if multimask prediction is enabled + Returns: + Dice loss tensor + """ + inputs = inputs.sigmoid() + if loss_on_multimask: + # inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks + assert inputs.dim() == 4 and targets.dim() == 4 + # flatten spatial dimension while keeping multimask channel dimension + inputs = inputs.flatten(2) + targets = targets.flatten(2) + numerator = 2 * (inputs * targets).sum(-1) + else: + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + if loss_on_multimask: + return loss / num_objects + return loss.sum() / num_objects + + +def sigmoid_focal_loss( + inputs, + targets, + num_objects, + alpha: float = 0.25, + gamma: float = 2, + loss_on_multimask=False, +): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). 
+ num_objects: Number of objects in the batch + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + loss_on_multimask: True if multimask prediction is enabled + Returns: + focal loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + if loss_on_multimask: + # loss is [N, M, H, W] where M corresponds to multiple predicted masks + assert loss.dim() == 4 + return loss.flatten(2).mean(-1) / num_objects # average over spatial dims + return loss.mean(1).sum() / num_objects + + +def iou_loss( + inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False +): + """ + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). 
+ pred_ious: A float tensor containing the predicted IoUs scores per mask + num_objects: Number of objects in the batch + loss_on_multimask: True if multimask prediction is enabled + use_l1_loss: Whether to use L1 loss is used instead of MSE loss + Returns: + IoU loss tensor + """ + assert inputs.dim() == 4 and targets.dim() == 4 + pred_mask = inputs.flatten(2) > 0 + gt_mask = targets.flatten(2) > 0 + area_i = torch.sum(pred_mask & gt_mask, dim=-1).float() + area_u = torch.sum(pred_mask | gt_mask, dim=-1).float() + actual_ious = area_i / torch.clamp(area_u, min=1.0) + + if use_l1_loss: + loss = F.l1_loss(pred_ious, actual_ious, reduction="none") + else: + loss = F.mse_loss(pred_ious, actual_ious, reduction="none") + if loss_on_multimask: + return loss / num_objects + return loss.sum() / num_objects + + +class MultiStepMultiMasksAndIous(nn.Module): + def __init__( + self, + weight_dict, + focal_alpha=0.25, + focal_gamma=2, + supervise_all_iou=False, + iou_use_l1_loss=False, + pred_obj_scores=False, + focal_gamma_obj_score=0.0, + focal_alpha_obj_score=-1, + ): + """ + This class computes the multi-step multi-mask and IoU losses. 
+ Args: + weight_dict: dict containing weights for focal, dice, iou losses + focal_alpha: alpha for sigmoid focal loss + focal_gamma: gamma for sigmoid focal loss + supervise_all_iou: if True, back-prop iou losses for all predicted masks + iou_use_l1_loss: use L1 loss instead of MSE loss for iou + pred_obj_scores: if True, compute loss for object scores + focal_gamma_obj_score: gamma for sigmoid focal loss on object scores + focal_alpha_obj_score: alpha for sigmoid focal loss on object scores + """ + + super().__init__() + self.weight_dict = weight_dict + self.focal_alpha = focal_alpha + self.focal_gamma = focal_gamma + assert "loss_mask" in self.weight_dict + assert "loss_dice" in self.weight_dict + assert "loss_iou" in self.weight_dict + if "loss_class" not in self.weight_dict: + self.weight_dict["loss_class"] = 0.0 + + self.focal_alpha_obj_score = focal_alpha_obj_score + self.focal_gamma_obj_score = focal_gamma_obj_score + self.supervise_all_iou = supervise_all_iou + self.iou_use_l1_loss = iou_use_l1_loss + self.pred_obj_scores = pred_obj_scores + + def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor): + assert len(outs_batch) == len(targets_batch) + num_objects = torch.tensor( + (targets_batch.shape[1]), device=targets_batch.device, dtype=torch.float + ) # Number of objects is fixed within a batch + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_objects) + num_objects = torch.clamp(num_objects / get_world_size(), min=1).item() + + losses = defaultdict(int) + for outs, targets in zip(outs_batch, targets_batch): + cur_losses = self._forward(outs, targets, num_objects) + for k, v in cur_losses.items(): + losses[k] += v + + return losses + + def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects): + """ + Compute the losses related to the masks: the focal loss and the dice loss. + and also the MAE or MSE loss between predicted IoUs and actual IoUs. 
+ + Here "multistep_pred_multimasks_high_res" is a list of multimasks (tensors + of shape [N, M, H, W], where M could be 1 or larger, corresponding to + one or multiple predicted masks from a click. + + We back-propagate focal, dice losses only on the prediction channel + with the lowest focal+dice loss between predicted mask and ground-truth. + If `supervise_all_iou` is True, we backpropagate ious losses for all predicted masks. + """ + + target_masks = targets.unsqueeze(1).float() + assert target_masks.dim() == 4 # [N, 1, H, W] + src_masks_list = outputs["multistep_pred_multimasks_high_res"] + ious_list = outputs["multistep_pred_ious"] + object_score_logits_list = outputs["multistep_object_score_logits"] + + assert len(src_masks_list) == len(ious_list) + assert len(object_score_logits_list) == len(ious_list) + + # accumulate the loss over prediction steps + losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0} + for src_masks, ious, object_score_logits in zip( + src_masks_list, ious_list, object_score_logits_list + ): + self._update_losses( + losses, src_masks, target_masks, ious, num_objects, object_score_logits + ) + losses[CORE_LOSS_KEY] = self.reduce_loss(losses) + return losses + + def _update_losses( + self, losses, src_masks, target_masks, ious, num_objects, object_score_logits + ): + target_masks = target_masks.expand_as(src_masks) + # get focal, dice and iou loss on all output masks in a prediction step + loss_multimask = sigmoid_focal_loss( + src_masks, + target_masks, + num_objects, + alpha=self.focal_alpha, + gamma=self.focal_gamma, + loss_on_multimask=True, + ) + loss_multidice = dice_loss( + src_masks, target_masks, num_objects, loss_on_multimask=True + ) + if not self.pred_obj_scores: + loss_class = torch.tensor( + 0.0, dtype=loss_multimask.dtype, device=loss_multimask.device + ) + target_obj = torch.ones( + loss_multimask.shape[0], + 1, + dtype=loss_multimask.dtype, + device=loss_multimask.device, + ) + else: + target_obj = 
torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[ + ..., None + ].float() + loss_class = sigmoid_focal_loss( + object_score_logits, + target_obj, + num_objects, + alpha=self.focal_alpha_obj_score, + gamma=self.focal_gamma_obj_score, + ) + + loss_multiiou = iou_loss( + src_masks, + target_masks, + ious, + num_objects, + loss_on_multimask=True, + use_l1_loss=self.iou_use_l1_loss, + ) + assert loss_multimask.dim() == 2 + assert loss_multidice.dim() == 2 + assert loss_multiiou.dim() == 2 + if loss_multimask.size(1) > 1: + # take the mask indices with the smallest focal + dice loss for back propagation + loss_combo = ( + loss_multimask * self.weight_dict["loss_mask"] + + loss_multidice * self.weight_dict["loss_dice"] + ) + best_loss_inds = torch.argmin(loss_combo, dim=-1) + batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device) + loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1) + loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1) + # calculate the iou prediction and slot losses only in the index + # with the minimum loss for each mask (to be consistent w/ SAM) + if self.supervise_all_iou: + loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1) + else: + loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1) + else: + loss_mask = loss_multimask + loss_dice = loss_multidice + loss_iou = loss_multiiou + + # backprop focal, dice and iou loss only if obj present + loss_mask = loss_mask * target_obj + loss_dice = loss_dice * target_obj + loss_iou = loss_iou * target_obj + + # sum over batch dimension (note that the losses are already divided by num_objects) + losses["loss_mask"] += loss_mask.sum() + losses["loss_dice"] += loss_dice.sum() + losses["loss_iou"] += loss_iou.sum() + losses["loss_class"] += loss_class + + def reduce_loss(self, losses): + reduced_loss = 0.0 + for loss_key, weight in self.weight_dict.items(): + if loss_key not in losses: + raise ValueError(f"{type(self)} doesn't compute {loss_key}") + 
if weight != 0: + reduced_loss += losses[loss_key] * weight + + return reduced_loss diff --git a/saber/finetune/train.py b/saber/finetune/train.py new file mode 100644 index 0000000..5f4e2ad --- /dev/null +++ b/saber/finetune/train.py @@ -0,0 +1,68 @@ +from sam2.sam2_image_predictor import SAM2ImagePredictor +from saber.finetune.trainer import SAM2FinetuneTrainer +from saber.finetune.dataset import AutoMaskDataset +from saber.finetune.helper import collate_autoseg +from saber.utils.slurm_submit import sam2_inputs +from torch.utils.data import DataLoader +from sam2.build_sam import build_sam2 +from saber import pretrained_weights +from saber.utils import io +import click + +def finetune_sam2( + tomo_train: str = None, + fib_train: str = None, + tomo_val: str = None, + fib_val: str = None, + sam2_cfg: str = 'base', + deviceID: int = 0, + num_epochs: int = 1000): + """ + Finetune SAM2 on tomograms and FIBs + """ + + # Determine device + device = io.get_available_devices(deviceID) + + (cfg, checkpoint) = pretrained_weights.get_sam2_checkpoint(sam2_cfg) + sam2_model = build_sam2(cfg, checkpoint, device=device, postprocess_mask=False) + predictor = SAM2ImagePredictor(sam2_model) + + # Option 1 : Train the Mask Decoder and Prompt Encoder + predictor.model.sam_mask_decoder.train(True) + predictor.model.sam_prompt_encoder.train(True) + + # Load data loaders + train_dataset = DataLoader(AutoMaskDataset(tomo_train, fib_train), batch_size=16, shuffle=True, + num_workers=4, pin_memory=True, collate_fn=collate_autoseg) + val_dataset = DataLoader(AutoMaskDataset(tomo_val, fib_val), batch_size=16, shuffle=False, + num_workers=4, pin_memory=True, collate_fn=collate_autoseg) + + # Initialize trainer and train + trainer = SAM2FinetuneTrainer(predictor, train_dataset, val_dataset, device) + trainer.train(train_dataset, val_dataset, num_epochs) + + # Save Results and Model + +@click.command() +@sam2_inputs +@click.option("--epochs", type=int, default=10, help="Number of epochs to train 
for") +@click.option("--train-zarr", type=str, help="Path to train Zarr") +@click.option("--val-zarr", type=str, help="Path to val Zarr") +def finetune(sam2_cfg: str, deviceID: int, num_epochs: int, train_zarr: str, val_zarr: str): + """ + Finetune SAM2 on 3D Volumes. Images from input tomograms and fibs are generated with slabs and slices, respectively. + """ + + print("--------------------------------") + print( + f"Fine Tuning SAM2 on {train_zarr} and {val_zarr} for {num_epochs} epochs" + ) + print(f"Using SAM2 Config: {sam2_cfg}") + print(f"Using Device: {deviceID}") + print(f"Using Number of Epochs: {num_epochs}") + print(f"Using Train Zarr: {train_zarr}") + print(f"Using Val Zarr: {val_zarr}") + print("--------------------------------") + + finetune_sam2(train_zarr, val_zarr, sam2_cfg, deviceID, num_epochs) \ No newline at end of file diff --git a/saber/finetune/trainer.py b/saber/finetune/trainer.py new file mode 100644 index 0000000..2e0a86d --- /dev/null +++ b/saber/finetune/trainer.py @@ -0,0 +1,33 @@ +from lightning.fabric import Fabric +from tqdm import tqdm +import torch + +class SAM2FinetuneTrainer: + def __init__(self, model, train_loader, val_loader, device): + self.model = model + self.train_loader = train_loader + self.val_loader = val_loader + self.device = device + + # Two parameter groups for different LRs (optional) + params = [ + {"params": [p for p in model.sam_mask_decoder.parameters() if p.requires_grad], + "lr": 1e-4}, + {"params": [p for p in model.sam_prompt_encoder.parameters() if p.requires_grad], + "lr": 5e-5}, + ] + + self.optimizer = torch.optim.AdamW(params, weight_decay=4e-5) + self.scaler = torch.cuda.amp.GradScaler() + + def train(self, train_loader, val_loader, num_epochs): + """ + Fine Tune SAM2 on the given data. 
+ """ + + best_metric_value = -1 + for epoch in tqdm(range(num_epochs)): + + # Reset results for this epoch + epoch_loss_train = 0 + epoch_loss_val = 0 diff --git a/saber/main.py b/saber/main.py index 19169e4..cb700bd 100644 --- a/saber/main.py +++ b/saber/main.py @@ -1,14 +1,12 @@ -from saber.entry_points.segment_methods import methods as segment from saber.classifier.cli import classifier_routines as classifier from saber.entry_points.run_low_pass_filter import cli as filter3d +from saber.entry_points.segment_methods import methods as segment from saber.analysis.analysis_cli import methods as analysis from saber.entry_points.run_analysis import cli as save -from saber.pretrained_weights import cli as download -from saber.utils.importers import cli as importers +from saber.finetune.train import finetune import click try: from saber.gui.base.zarr_gui import gui - from saber.gui.text.zarr_text_gui import text_gui gui_available = True except Exception as e: print(f"GUI is not available: {e}") @@ -23,9 +21,9 @@ def routines(): routines.add_command(analysis) routines.add_command(filter3d) routines.add_command(classifier) +routines.add_command(finetune) if gui_available: routines.add_command(gui) - routines.add_command(text_gui) routines.add_command(segment) routines.add_command(save) diff --git a/saber/utils/preprocessing.py b/saber/utils/preprocessing.py index f024cd9..82697ff 100644 --- a/saber/utils/preprocessing.py +++ b/saber/utils/preprocessing.py @@ -3,7 +3,7 @@ def contrast(image, std_cutoff=5): """ - Normalize the Input Data to [0,1] + Clip the Image by ±5std """ image_mean = uniform_filter(image, size=500) image_sq = uniform_filter(image**2, size=500) @@ -14,7 +14,9 @@ def contrast(image, std_cutoff=5): return np.clip(image, -std_cutoff, std_cutoff) def normalize(image, rgb = False): - # Clip the Volume by ±5std + """ + Normalize the Input Data to [0,1] + """ if rgb: min_vals = image.min(axis=(0, 1), keepdims=True) max_vals = image.max(axis=(0, 1), 
keepdims=True) @@ -24,6 +26,23 @@ def normalize(image, rgb = False): normalized = (image - min_vals) / (max_vals - min_vals + 1e-8) # Add epsilon to avoid div by zero return normalized +def proprocess(image: np.ndarray, std_cutoff=3, rgb=False): + """ + Preprocesses an image for segmentation. + + Parameters: + image (np.ndarray): 3D image array (z, y, x). + std_cutoff (int, optional): Standard deviation cutoff for contrast normalization. + rgb (bool, optional): Whether the input image is RGB. + + Returns: + np.ndarray: Preprocessed image. + """ + image = contrast(image, std_cutoff=std_cutoff) + image = normalize(image, rgb=rgb) + image = np.repeat(image[..., None], 3, axis=2) if rgb else image + return image + def project_tomogram(vol, zSlice = None, deltaZ = None): """ Projects a tomogram along the z-axis. @@ -51,3 +70,24 @@ def project_tomogram(vol, zSlice = None, deltaZ = None): projection = np.mean(vol, axis=0) return projection + +def project_segmentation(vol, zSlice = None, deltaZ = None): + """ + Projects a segmentation along the z-axis. + + Parameters: + vol (np.ndarray): 3D segmentation array (z, y, x). + zSlice (int, optional): Specific z-slice to project. If None, project along all z slices. + deltaZ (int, optional): Thickness of slices to project. Used only if zSlice is specified. If None, project a single slice. 
+ """ + + if zSlice is not None: + if deltaZ is not None: + zStart = int(max(zSlice - deltaZ, 0)) + zEnd = int(min(zSlice + deltaZ, vol.shape[0])) + projection = np.max(vol[zStart:zEnd,], axis=0) + else: + projection = np.max(vol[zSlice,], axis=0) + else: + projection = np.max(vol, axis=0) + return projection \ No newline at end of file From 6339670285b56d9a6030cc4fa08bc7cddf1085b1 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 15 Sep 2025 18:43:15 +0000 Subject: [PATCH 02/15] add fabrics as a dependency --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9db81ac..b133761 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,11 +28,12 @@ dependencies = [ "monai", "click", "copick", + "kornia", "nibabel", "mrcfile", "starfile", + "lightning", "matplotlib", - "kornia", "opencv-python", "multiprocess", "torchmetrics", From c5e00fcb20466a040560166052595765e3aa729a Mon Sep 17 00:00:00 2001 From: root Date: Mon, 15 Sep 2025 23:41:29 +0000 Subject: [PATCH 03/15] Begin drafting the training loops --- saber/finetune/dataset.py | 5 + saber/finetune/losses.py | 379 +++++++++----------------------------- saber/finetune/train.py | 41 ++--- saber/finetune/trainer.py | 192 +++++++++++++++++-- 4 files changed, 294 insertions(+), 323 deletions(-) diff --git a/saber/finetune/dataset.py b/saber/finetune/dataset.py index 9fdc4b6..4849ead 100644 --- a/saber/finetune/dataset.py +++ b/saber/finetune/dataset.py @@ -222,6 +222,11 @@ def _package_image_item(self, labels_t = [torch.from_numpy(np.ones((1,), dtype=np.float32))] boxes_t = [torch.from_numpy(np.array([0, 0, 1, 1], dtype=np.float32))] + # Apply transforms + if self.transform: + data = self.transform({'image': image_2d, 'masks': masks_t}) + image_2d, masks_t = data['image'], data['masks'] + return { "image": image_2d, # HxWx3 uint8 "masks": masks_t, # list[H x W] float32 in {0,1} diff --git a/saber/finetune/losses.py b/saber/finetune/losses.py index 
d281b1a..4f9bb98 100644 --- a/saber/finetune/losses.py +++ b/saber/finetune/losses.py @@ -1,307 +1,112 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from collections import defaultdict -from typing import Dict, List - import torch -import torch.distributed import torch.nn as nn import torch.nn.functional as F -from training.trainer import CORE_LOSS_KEY - -from training.utils.distributed import get_world_size, is_dist_avail_and_initialized - - -def dice_loss(inputs, targets, num_objects, loss_on_multimask=False): - """ - Compute the DICE loss, similar to generalized IOU for masks - Args: - inputs: A float tensor of arbitrary shape. - The predictions for each example. - targets: A float tensor with the same shape as inputs. Stores the binary - classification label for each element in inputs - (0 for the negative class and 1 for the positive class). 
- num_objects: Number of objects in the batch - loss_on_multimask: True if multimask prediction is enabled - Returns: - Dice loss tensor - """ - inputs = inputs.sigmoid() - if loss_on_multimask: - # inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks - assert inputs.dim() == 4 and targets.dim() == 4 - # flatten spatial dimension while keeping multimask channel dimension - inputs = inputs.flatten(2) - targets = targets.flatten(2) - numerator = 2 * (inputs * targets).sum(-1) - else: - inputs = inputs.flatten(1) - numerator = 2 * (inputs * targets).sum(1) - denominator = inputs.sum(-1) + targets.sum(-1) - loss = 1 - (numerator + 1) / (denominator + 1) - if loss_on_multimask: - return loss / num_objects - return loss.sum() / num_objects - - -def sigmoid_focal_loss( - inputs, - targets, - num_objects, - alpha: float = 0.25, - gamma: float = 2, - loss_on_multimask=False, -): - """ - Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. - Args: - inputs: A float tensor of arbitrary shape. - The predictions for each example. - targets: A float tensor with the same shape as inputs. Stores the binary - classification label for each element in inputs - (0 for the negative class and 1 for the positive class). - num_objects: Number of objects in the batch - alpha: (optional) Weighting factor in range (0,1) to balance - positive vs negative examples. Default = -1 (no weighting). - gamma: Exponent of the modulating factor (1 - p_t) to - balance easy vs hard examples. 
- loss_on_multimask: True if multimask prediction is enabled - Returns: - focal loss tensor - """ - prob = inputs.sigmoid() - ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") - p_t = prob * targets + (1 - prob) * (1 - targets) - loss = ce_loss * ((1 - p_t) ** gamma) - - if alpha >= 0: - alpha_t = alpha * targets + (1 - alpha) * (1 - targets) - loss = alpha_t * loss - - if loss_on_multimask: - # loss is [N, M, H, W] where M corresponds to multiple predicted masks - assert loss.dim() == 4 - return loss.flatten(2).mean(-1) / num_objects # average over spatial dims - return loss.mean(1).sum() / num_objects +def dice_loss_from_logits(logits, targets, eps=1e-6): + probs = torch.sigmoid(logits) + inter = (probs * targets).sum(dim=(1, 2)) + denom = probs.sum(dim=(1, 2)) + targets.sum(dim=(1, 2)) + return 1 - (2 * inter + eps) / (denom + eps) +def focal_loss_from_logits(logits, targets, alpha=0.25, gamma=2.0): + ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none") + p = torch.sigmoid(logits) + pt = p * targets + (1 - p) * (1 - targets) + w = alpha * targets + (1 - alpha) * (1 - targets) + return (w * ((1 - pt) ** gamma) * ce).mean(dim=(1, 2)) -def iou_loss( - inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False -): +class MultiMaskIoULoss(nn.Module): """ - Args: - inputs: A float tensor of arbitrary shape. - The predictions for each example. - targets: A float tensor with the same shape as inputs. Stores the binary - classification label for each element in inputs - (0 for the negative class and 1 for the positive class). - pred_ious: A float tensor containing the predicted IoUs scores per mask - num_objects: Number of objects in the batch - loss_on_multimask: True if multimask prediction is enabled - use_l1_loss: Whether to use L1 loss is used instead of MSE loss - Returns: - IoU loss tensor + General loss for multi-mask predictions with IoU calibration. 
+ Designed for SAM/SAM2 fine-tuning with AMG. """ - assert inputs.dim() == 4 and targets.dim() == 4 - pred_mask = inputs.flatten(2) > 0 - gt_mask = targets.flatten(2) > 0 - area_i = torch.sum(pred_mask & gt_mask, dim=-1).float() - area_u = torch.sum(pred_mask | gt_mask, dim=-1).float() - actual_ious = area_i / torch.clamp(area_u, min=1.0) - - if use_l1_loss: - loss = F.l1_loss(pred_ious, actual_ious, reduction="none") - else: - loss = F.mse_loss(pred_ious, actual_ious, reduction="none") - if loss_on_multimask: - return loss / num_objects - return loss.sum() / num_objects - - -class MultiStepMultiMasksAndIous(nn.Module): - def __init__( - self, - weight_dict, - focal_alpha=0.25, - focal_gamma=2, - supervise_all_iou=False, - iou_use_l1_loss=False, - pred_obj_scores=False, - focal_gamma_obj_score=0.0, - focal_alpha_obj_score=-1, - ): - """ - This class computes the multi-step multi-mask and IoU losses. - Args: - weight_dict: dict containing weights for focal, dice, iou losses - focal_alpha: alpha for sigmoid focal loss - focal_gamma: gamma for sigmoid focal loss - supervise_all_iou: if True, back-prop iou losses for all predicted masks - iou_use_l1_loss: use L1 loss instead of MSE loss for iou - pred_obj_scores: if True, compute loss for object scores - focal_gamma_obj_score: gamma for sigmoid focal loss on object scores - focal_alpha_obj_score: alpha for sigmoid focal loss on object scores - """ + def __init__(self, + weight_dict: dict, + focal_alpha=0.25, + focal_gamma=2.0, + supervise_all_iou=False, + iou_use_l1_loss=True): super().__init__() self.weight_dict = weight_dict + assert "loss_mask" in weight_dict + assert "loss_dice" in weight_dict + assert "loss_iou" in weight_dict self.focal_alpha = focal_alpha self.focal_gamma = focal_gamma - assert "loss_mask" in self.weight_dict - assert "loss_dice" in self.weight_dict - assert "loss_iou" in self.weight_dict - if "loss_class" not in self.weight_dict: - self.weight_dict["loss_class"] = 0.0 - - 
self.focal_alpha_obj_score = focal_alpha_obj_score - self.focal_gamma_obj_score = focal_gamma_obj_score self.supervise_all_iou = supervise_all_iou self.iou_use_l1_loss = iou_use_l1_loss - self.pred_obj_scores = pred_obj_scores - - def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor): - assert len(outs_batch) == len(targets_batch) - num_objects = torch.tensor( - (targets_batch.shape[1]), device=targets_batch.device, dtype=torch.float - ) # Number of objects is fixed within a batch - if is_dist_avail_and_initialized(): - torch.distributed.all_reduce(num_objects) - num_objects = torch.clamp(num_objects / get_world_size(), min=1).item() - - losses = defaultdict(int) - for outs, targets in zip(outs_batch, targets_batch): - cur_losses = self._forward(outs, targets, num_objects) - for k, v in cur_losses.items(): - losses[k] += v - - return losses - def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects): + def forward(self, prd_masks, prd_scores, gt_masks): """ - Compute the losses related to the masks: the focal loss and the dice loss. - and also the MAE or MSE loss between predicted IoUs and actual IoUs. - - Here "multistep_pred_multimasks_high_res" is a list of multimasks (tensors - of shape [N, M, H, W], where M could be 1 or larger, corresponding to - one or multiple predicted masks from a click. - - We back-propagate focal, dice losses only on the prediction channel - with the lowest focal+dice loss between predicted mask and ground-truth. - If `supervise_all_iou` is True, we backpropagate ious losses for all predicted masks. 
+ Args + ---- + prd_masks: [N, K, H, W] logits from decoder + prd_scores: [N, K] predicted IoU scores + gt_masks: [N, H, W] float {0,1} """ - - target_masks = targets.unsqueeze(1).float() - assert target_masks.dim() == 4 # [N, 1, H, W] - src_masks_list = outputs["multistep_pred_multimasks_high_res"] - ious_list = outputs["multistep_pred_ious"] - object_score_logits_list = outputs["multistep_object_score_logits"] - - assert len(src_masks_list) == len(ious_list) - assert len(object_score_logits_list) == len(ious_list) - - # accumulate the loss over prediction steps - losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0} - for src_masks, ious, object_score_logits in zip( - src_masks_list, ious_list, object_score_logits_list - ): - self._update_losses( - losses, src_masks, target_masks, ious, num_objects, object_score_logits - ) - losses[CORE_LOSS_KEY] = self.reduce_loss(losses) - return losses - - def _update_losses( - self, losses, src_masks, target_masks, ious, num_objects, object_score_logits - ): - target_masks = target_masks.expand_as(src_masks) - # get focal, dice and iou loss on all output masks in a prediction step - loss_multimask = sigmoid_focal_loss( - src_masks, - target_masks, - num_objects, - alpha=self.focal_alpha, - gamma=self.focal_gamma, - loss_on_multimask=True, - ) - loss_multidice = dice_loss( - src_masks, target_masks, num_objects, loss_on_multimask=True - ) - if not self.pred_obj_scores: - loss_class = torch.tensor( - 0.0, dtype=loss_multimask.dtype, device=loss_multimask.device - ) - target_obj = torch.ones( - loss_multimask.shape[0], - 1, - dtype=loss_multimask.dtype, - device=loss_multimask.device, - ) - else: - target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[ - ..., None - ].float() - loss_class = sigmoid_focal_loss( - object_score_logits, - target_obj, - num_objects, - alpha=self.focal_alpha_obj_score, - gamma=self.focal_gamma_obj_score, - ) - - loss_multiiou = iou_loss( - src_masks, - target_masks, - 
ious, - num_objects, - loss_on_multimask=True, - use_l1_loss=self.iou_use_l1_loss, - ) - assert loss_multimask.dim() == 2 - assert loss_multidice.dim() == 2 - assert loss_multiiou.dim() == 2 - if loss_multimask.size(1) > 1: - # take the mask indices with the smallest focal + dice loss for back propagation - loss_combo = ( - loss_multimask * self.weight_dict["loss_mask"] - + loss_multidice * self.weight_dict["loss_dice"] - ) - best_loss_inds = torch.argmin(loss_combo, dim=-1) - batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device) - loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1) - loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1) - # calculate the iou prediction and slot losses only in the index - # with the minimum loss for each mask (to be consistent w/ SAM) - if self.supervise_all_iou: - loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1) + N, K, H, W = prd_masks.shape + + # compute per-proposal losses + loss_mask_k, loss_dice_k = [], [] + for k in range(K): + l_focal = focal_loss_from_logits( + prd_masks[:, k], gt_masks, + alpha=self.focal_alpha, gamma=self.focal_gamma + ) # scalar over batch + l_dice = dice_loss_from_logits(prd_masks[:, k], gt_masks) # [N] + loss_mask_k.append(l_focal.expand_as(l_dice)) + loss_dice_k.append(l_dice) + loss_mask_k = torch.stack(loss_mask_k, dim=1) # [N,K] + loss_dice_k = torch.stack(loss_dice_k, dim=1) # [N,K] + + # combine to pick best proposal per instance + combo = (self.weight_dict["loss_mask"] * loss_mask_k + + self.weight_dict["loss_dice"] * loss_dice_k) + best_idx = combo.argmin(dim=1) # [N] + row = torch.arange(N, device=prd_masks.device) + + # select best proposal losses + loss_mask = loss_mask_k[row, best_idx].mean() + loss_dice = loss_dice_k[row, best_idx].mean() + + # IoU calibration loss + with torch.no_grad(): + probs = torch.sigmoid(prd_masks[row, best_idx]) # [N,H,W] + pred_bin = (probs > 0.5).float() + inter = (gt_masks * pred_bin).sum(dim=(1, 2)) + union = 
gt_masks.sum(dim=(1, 2)) + pred_bin.sum(dim=(1, 2)) - inter + 1e-6 + true_iou = inter / union # [N] + + if self.supervise_all_iou: + # supervise all proposals + iou_targets = [] + for k in range(K): + probs = torch.sigmoid(prd_masks[:, k]) + pred_bin = (probs > 0.5).float() + inter = (gt_masks * pred_bin).sum(dim=(1, 2)) + union = gt_masks.sum(dim=(1, 2)) + pred_bin.sum(dim=(1, 2)) - inter + 1e-6 + iou_targets.append(inter / union) + iou_targets = torch.stack(iou_targets, dim=1) # [N,K] + if self.iou_use_l1_loss: + loss_iou = F.l1_loss(prd_scores, iou_targets) else: - loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1) + loss_iou = F.mse_loss(prd_scores, iou_targets) else: - loss_mask = loss_multimask - loss_dice = loss_multidice - loss_iou = loss_multiiou - - # backprop focal, dice and iou loss only if obj present - loss_mask = loss_mask * target_obj - loss_dice = loss_dice * target_obj - loss_iou = loss_iou * target_obj - - # sum over batch dimension (note that the losses are already divided by num_objects) - losses["loss_mask"] += loss_mask.sum() - losses["loss_dice"] += loss_dice.sum() - losses["loss_iou"] += loss_iou.sum() - losses["loss_class"] += loss_class - - def reduce_loss(self, losses): - reduced_loss = 0.0 - for loss_key, weight in self.weight_dict.items(): - if loss_key not in losses: - raise ValueError(f"{type(self)} doesn't compute {loss_key}") - if weight != 0: - reduced_loss += losses[loss_key] * weight - - return reduced_loss + score_best = prd_scores[row, best_idx] # [N] + if self.iou_use_l1_loss: + loss_iou = F.l1_loss(score_best, true_iou) + else: + loss_iou = F.mse_loss(score_best, true_iou) + + # weighted sum + total_loss = (self.weight_dict["loss_mask"] * loss_mask + + self.weight_dict["loss_dice"] * loss_dice + + self.weight_dict["loss_iou"] * loss_iou) + + return { + "loss_mask": loss_mask, + "loss_dice": loss_dice, + "loss_iou": loss_iou, + "loss_total": total_loss, + } \ No newline at end of file diff --git 
a/saber/finetune/train.py b/saber/finetune/train.py index 5f4e2ad..9a0861b 100644 --- a/saber/finetune/train.py +++ b/saber/finetune/train.py @@ -15,17 +15,15 @@ def finetune_sam2( tomo_val: str = None, fib_val: str = None, sam2_cfg: str = 'base', - deviceID: int = 0, num_epochs: int = 1000): """ Finetune SAM2 on tomograms and FIBs """ # Determine device - device = io.get_available_devices(deviceID) - + device = io.get_available_devices(0) (cfg, checkpoint) = pretrained_weights.get_sam2_checkpoint(sam2_cfg) - sam2_model = build_sam2(cfg, checkpoint, device=device, postprocess_mask=False) + sam2_model = build_sam2(cfg, checkpoint, device='cuda', postprocess_mask=False) predictor = SAM2ImagePredictor(sam2_model) # Option 1 : Train the Mask Decoder and Prompt Encoder @@ -33,36 +31,37 @@ def finetune_sam2( predictor.model.sam_prompt_encoder.train(True) # Load data loaders - train_dataset = DataLoader(AutoMaskDataset(tomo_train, fib_train), batch_size=16, shuffle=True, + train_loader = DataLoader(AutoMaskDataset(tomo_train, fib_train), batch_size=16, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate_autoseg) - val_dataset = DataLoader(AutoMaskDataset(tomo_val, fib_val), batch_size=16, shuffle=False, - num_workers=4, pin_memory=True, collate_fn=collate_autoseg) + val_loader = DataLoader(AutoMaskDataset(tomo_val, fib_val), batch_size=16, shuffle=False, + num_workers=4, pin_memory=True, collate_fn=collate_autoseg) if (tomo_val or fib_val) else None # Initialize trainer and train - trainer = SAM2FinetuneTrainer(predictor, train_dataset, val_dataset, device) - trainer.train(train_dataset, val_dataset, num_epochs) - - # Save Results and Model + trainer = SAM2FinetuneTrainer(predictor, train_loader, val_loader) + trainer.train( num_epochs ) @click.command() @sam2_inputs -@click.option("--epochs", type=int, default=10, help="Number of epochs to train for") -@click.option("--train-zarr", type=str, help="Path to train Zarr") -@click.option("--val-zarr", type=str, 
help="Path to val Zarr") -def finetune(sam2_cfg: str, deviceID: int, num_epochs: int, train_zarr: str, val_zarr: str): +@click.option("--epochs", type=int, default=1000, help="Number of epochs to train for") +@click.option("--fib-train", type=str, help="Path to train Zarr") +@click.option("--fib-val", type=str, help="Path to val Zarr") +@click.option("--tomo-train", type=str, help="Path to train Zarr") +@click.option("--tomo-val", type=str, help="Path to val Zarr") +def finetune(sam2_cfg: str, epochs: int, fib_train: str, fib_val: str, tomo_train: str, tomo_val: str): """ Finetune SAM2 on 3D Volumes. Images from input tomograms and fibs are generated with slabs and slices, respectively. """ print("--------------------------------") print( - f"Fine Tuning SAM2 on {train_zarr} and {val_zarr} for {num_epochs} epochs" + f"Fine Tuning SAM2 on {fib_train} and {fib_val} and {tomo_train} and {tomo_val} for {epochs} epochs" ) print(f"Using SAM2 Config: {sam2_cfg}") - print(f"Using Device: {deviceID}") - print(f"Using Number of Epochs: {num_epochs}") - print(f"Using Train Zarr: {train_zarr}") - print(f"Using Val Zarr: {val_zarr}") + print(f"Using Number of Epochs: {epochs}") + print(f"Using Train Zarr: {fib_train}") + print(f"Using Val Zarr: {fib_val}") + print(f"Using Train Zarr: {tomo_train}") + print(f"Using Val Zarr: {tomo_val}") print("--------------------------------") - finetune_sam2(train_zarr, val_zarr, sam2_cfg, deviceID, num_epochs) \ No newline at end of file + finetune_sam2(tomo_train, fib_train, tomo_val, fib_val, sam2_cfg, epochs) \ No newline at end of file diff --git a/saber/finetune/trainer.py b/saber/finetune/trainer.py index 2e0a86d..cbad181 100644 --- a/saber/finetune/trainer.py +++ b/saber/finetune/trainer.py @@ -1,33 +1,195 @@ -from lightning.fabric import Fabric +from saber.finetune.losses import MultiMaskIoULoss +from lightning import fabric from tqdm import tqdm -import torch +import torch, os class SAM2FinetuneTrainer: - def __init__(self, model, 
train_loader, val_loader, device): - self.model = model - self.train_loader = train_loader - self.val_loader = val_loader - self.device = device + def __init__(self, predictor, train_loader, val_loader): + + # Store the predictor + self.predictor = predictor # Two parameter groups for different LRs (optional) params = [ - {"params": [p for p in model.sam_mask_decoder.parameters() if p.requires_grad], + {"params": [p for p in self.predictor.model.sam_mask_decoder.parameters() if p.requires_grad], "lr": 1e-4}, - {"params": [p for p in model.sam_prompt_encoder.parameters() if p.requires_grad], + {"params": [p for p in self.predictor.model.sam_prompt_encoder.parameters() if p.requires_grad], "lr": 5e-5}, ] - self.optimizer = torch.optim.AdamW(params, weight_decay=4e-5) - self.scaler = torch.cuda.amp.GradScaler() + # Initialize the optimizer and dataloaders + self.num_gpus = torch.cuda.device_count() + self.fabric = fabric.Fabric(accelerator="cuda", strategy="ddp", devices=self.num_gpus) + optimizer = torch.optim.AdamW(params, weight_decay=4e-5) + self.predictor.model, self.optimizer = self.fabric.setup(self.predictor.model,optimizer) + + if val_loader is None: + self.train_loader = self.fabric.setup_dataloaders(train_loader) + else: + self.train_loader, self.val_loader = self.fabric.setup_dataloaders(train_loader, val_loader) + + # Initialize the loss function + self.focal_alpha = 0.25 + self.focal_gamma = 2.0 + self.supervise_all_iou = False + self.iou_use_l1_loss = True + + # Initialize the use_boxes flag + self.use_boxes = False + + # Initialize the save directory + self.save_dir = 'results' + os.makedirs(self.save_dir, exist_ok=True) + + @torch.no_grad() + def _stack_image_embeddings_from_predictor(self): + """ + After predictor.set_image_batch(images), gather stacked image embeddings + and high-res features for all B images. 
+ Returns: + image_embeds: [B, C, H', W'] + hr_feats: list[level] of [B, C, H', W'] + """ + # image_embed is a list[len=B] of [C, H', W']; stack to [B, C, H', W'] + image_embeds = torch.stack(list(self.predictor.model._features["image_embed"]), dim=0).to(self.fabric.device) + + # high_res_feats is a list[level], where each level is a list[len=B] of [C, H', W'] + hr = self.predictor.model._features["high_res_feats"] + B = image_embeds.shape[0] + hr_feats = [torch.stack([lvl[b] for b in range(B)], dim=0).to(self.fabric.device) for lvl in hr] + return image_embeds, hr_feats + + def forward_step(self, batch): + """ + Returns: prd_masks [N,K,H,W] logits, prd_scores [N,K], gt_masks [N,H,W], inst_img_ix [N] + """ + images = batch["images"] # list of HxWx3 uint8 or float; predictor handles them + B = len(images) + + # 1) Encode images once + self.predictor.set_image_batch(images) # caches features on predictor + image_embeds_B, hr_feats_B = self._stack_image_embeddings_from_predictor() + + # 2) Flatten instances across batch, move tensors to device + inst_img_ix, gt_all, pts_all, lbl_all, box_all = [], [], [], [], [] + for b in range(B): + for m, p, l, bx in zip(batch["masks"][b], batch["points"][b], batch["labels"][b], batch["boxes"][b]): + inst_img_ix.append(b) + gt_all.append(m.to(self.fabric.device)) + pts_all.append(p.to(self.fabric.device)) + lbl_all.append(l.to(self.fabric.device)) + box_all.append(bx.to(self.fabric.device)) + + N = len(gt_all) + if N == 0: + return None, None, None, None + inst_img_ix = torch.tensor(inst_img_ix, device=self.fabric.device, dtype=torch.long) + + # 3) Pad clicks to (N,P,2) and (N,P) + P = max(p.shape[0] for p in pts_all) + pts_pad = torch.zeros((N, P, 2), device=self.fabric.device, dtype=torch.float32) + lbl_pad = torch.zeros((N, P), device=self.fabric.device, dtype=torch.float32) + for i, (p, l) in enumerate(zip(pts_all, lbl_all)): + pts_pad[i, :p.shape[0]] = p + lbl_pad[i, :l.shape[0]] = l - def train(self, train_loader, val_loader, 
num_epochs): + # Optional boxes + boxes = torch.stack(box_all, dim=0) if (self.use_boxes and len(box_all) > 0) else None + + # 4) Prompt encoding + mask_input, unnorm_coords, labels, _ = self.predictor._prep_prompts( + input_point=pts_pad, input_label=lbl_pad, box=boxes, mask_logits=None, normalize_coords=True + ) + sparse_embeddings, dense_embeddings = self.predictor.model.sam_prompt_encoder( + points=(unnorm_coords, labels), + boxes=boxes if self.use_boxes else None, + masks=None, + ) + + # 5) Gather per-instance image feats + image_embeds = image_embeds_B[inst_img_ix] # [N,C,H',W'] + hr_feats = [lvl[inst_img_ix] for lvl in hr_feats_B] # list of [N,C,H',W'] + + # 6) Decode + low_res_masks, prd_scores, _, _ = self.predictor.model.sam_mask_decoder( + image_embeddings=image_embeds, + image_pe=self.predictor.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=True, + repeat_image=False, + high_res_features=hr_feats, + ) + + # 7) Upscale + stack GT + prd_masks = self.predictor._transforms.postprocess_masks( + low_res_masks, self.predictor._orig_hw[-1] + ) # [N,K,H,W] logits + gt_masks = torch.stack(gt_all, dim=0) # [N,H,W] + + return prd_masks, prd_scores, gt_masks, inst_img_ix + + @torch.no_grad() + def validate_step(self, batch): + """ + Validate the model on the given batch. + """ + if 'embeddings' in batch: + embeddings = batch['embeddings'] + else: + self.predictor.set_image_batch(batch['images']) + + # Run AMG to get proposals + # proposals = self.predictor + + metrics = {} + return metrics + + def train(self, num_epochs): """ Fine Tune SAM2 on the given data. 
""" - best_metric_value = -1 + # Initialize the loss function + self.loss_fn = MultiMaskIoULoss( + weight_dict={"loss_mask": 1.0, "loss_dice": 1.0, "loss_iou": 0.05}, + focal_alpha=self.focal_alpha, + focal_gamma=self.focal_gamma, + supervise_all_iou=self.supervise_all_iou, + iou_use_l1_loss=self.iou_use_l1_loss + ) + + self.optimizer.zero_grad() for epoch in tqdm(range(num_epochs)): - - # Reset results for this epoch + + # Initialize the epoch loss epoch_loss_train = 0 epoch_loss_val = 0 + + # Train + self.predictor.model.train() + for batch in self.train_loader: + with self.fabric.autocast(): + out = self.forward_step(batch) + if out[0] is None: + continue + prd_masks, prd_scores, gt_masks, _ = out + losses = self.loss_fn(prd_masks, prd_scores, gt_masks) + self.fabric.backward(losses) + self.optimizer.step() + self.optimizer.zero_grad() + epoch_loss_train += float(losses["loss_total"].detach().cpu()) + + # Validate + self.predictor.model.eval() + with torch.no_grad(): + for batch in self.val_loader: + with self.fabric.autocast(): + out = self.validate_step(batch) + losses = self.loss_fn(out) + + # Print Only on Rank 0 + if self.fabric.is_global_zero: + #Checkpoint + print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss_train/len(self.train_loader)}, Val Loss: {epoch_loss_val/len(self.val_loader)}") + torch.save(self.predictor.model.state_dict(), f"{self.save_dir}/model.pth") \ No newline at end of file From 2590ea4b07de1e16495ffd546d845feb5b15a640 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 17 Sep 2025 18:59:38 +0000 Subject: [PATCH 04/15] near ready draft --- saber/finetune/dataset.py | 128 +++++++++++++++----- saber/finetune/helper.py | 124 +++++++++++++++++-- saber/finetune/metrics.py | 222 +++++++++++++++++++++++++++++++++++ saber/finetune/train.py | 17 ++- saber/finetune/trainer.py | 161 +++++++++++++++++-------- saber/utils/preprocessing.py | 2 +- 6 files changed, 556 insertions(+), 98 deletions(-) create mode 100644 saber/finetune/metrics.py 
diff --git a/saber/finetune/dataset.py b/saber/finetune/dataset.py index 4849ead..4d71eed 100644 --- a/saber/finetune/dataset.py +++ b/saber/finetune/dataset.py @@ -1,6 +1,7 @@ import saber.finetune.helper as helper from saber.utils import preprocessing from torch.utils.data import Dataset +from tqdm import tqdm import numpy as np import zarr, torch @@ -11,7 +12,8 @@ def __init__(self, transform = None, slabs_per_volume_per_epoch: int = 10, slices_per_fib_per_epoch: int = 5, - slab_thickness: int = 5): + slab_thickness: int = 5, + seed: int = 42): """ Args: tomogram_zarr_path: Path to the tomogram zarr store @@ -29,9 +31,10 @@ def __init__(self, # Grid and Positive Points for AutoMaskGenerator self.points_per_side = 32 - self.points_per_batch = 64 self.min_area = 0.001 - self.k_pos = 2 + self.k_min = 1 + self.k_max = 10 + self.transform = transform # Check if both data types are available if tomogram_zarr_path is None and fib_zarr_path is None: @@ -45,10 +48,16 @@ def __init__(self, if self.has_tomogram: self.tomogram_store = zarr.open(tomogram_zarr_path, mode='r') self.tomogram_keys = [k for k in self.tomogram_store.keys() if not k.startswith('.')] - self.n_tomogram_volumes = len(self.tomogram_keys) self.tomo_shapes = {} - for i, key in enumerate(self.tomogram_keys): - self.tomo_shapes[i] = self.tomogram_store[key]['0'].shape + for i, key in tqdm(enumerate(self.tomogram_keys), + total=len(self.tomogram_keys), desc="Estimating zrange for tomograms"): + try: + self.tomo_shapes[i] = self._estimate_zrange(key) + except Exception as e: + print(f"Error estimating zrange for tomogram {key}: {e}") + # remove key from tomogram_keys + self.tomogram_keys.remove(key) + self.n_tomogram_volumes = len(self.tomogram_keys) else: self.n_tomogram_volumes = 0 self.tomo_shapes = {} @@ -66,12 +75,33 @@ def __init__(self, self.n_fib_volumes = 0 self.fib_shapes = {} self.fib_keys = [] + + # Random seed + self.seed = seed + self._rng = np.random.RandomState(seed) # Resample epoch 
self.resample_epoch() # Verbose Flag self.verbose = False + + def _estimate_zrange(self, key, band=(0.25, 0.75), threshold=0): + """ + Returns (z_min, z_max) inclusive bounds for valid slab centers + where there is some foreground in the labels. + - threshold: min # of fg pixels to count a slice as non-empty (0 = any) + - band: fraction of Z to consider (lo, hi) + """ + + nz = self.tomogram_store[key]['0'].shape[0] + min_offset, max_offset = int(nz * band[0]), int(nz * band[1]) + mask = self.tomogram_store[key]['labels/0'][min_offset:max_offset,] + vals = mask.sum(axis=(1,2)) + vals = np.nonzero(vals)[0] + max_val = vals.max() - self.slab_thickness // 2 + min_offset + min_val = vals.min() + self.slab_thickness // 2 + min_offset + return int(min_val), int(max_val) def resample_epoch(self): """ Generate new random samples for this epoch """ @@ -80,25 +110,24 @@ def resample_epoch(self): # Sample random slabs from each tomogram if self.has_tomogram: + print(f"Re-Sampling {self.slabs_per_volume_per_epoch} slabs from {self.n_tomogram_volumes} tomograms") for vol_idx in range(self.n_tomogram_volumes): - volume_shape = self.tomo_shapes[vol_idx] - # Valid range for center of slab - valid_z_min = int(volume_shape[0] / 4) - valid_z_max = int(volume_shape[0] * (3 / 4)) - - if valid_z_max > valid_z_min: - z_positions = np.random.randint( - valid_z_min, - valid_z_max, - size=self.slabs_per_volume_per_epoch - ) - - for z_pos in z_positions: - self.tomogram_samples.append((vol_idx, z_pos)) + + # Sample random z positions from this tomogram volume + z_min, z_max = self.tomo_shapes[vol_idx] + z_positions = np.random.randint( + z_min, + z_max, + size=self.slabs_per_volume_per_epoch + ) + # Add to samples + for z_pos in z_positions: + self.tomogram_samples.append((vol_idx, z_pos)) np.random.shuffle(self.tomogram_samples) # Shuffle samples # Sample random slices from each FIB volume if self.has_fib: + print(f"Re-Sampling {self.slices_per_fib_per_epoch} slices from {self.n_fib_volumes} 
FIB volumes") for fib_idx in range(self.n_fib_volumes): fib_shape = self.fib_shapes[fib_idx] # Sample random z positions from this FIB volume @@ -154,7 +183,7 @@ def _get_fib_item(self, idx): key = self.fib_keys[fib_idx] # Load FIB image and segmentation - image = self.fib_store[key]['0'][z_pos,] + image = self.fib_store[key]['0'][z_pos,].astype(np.float32) image_2d = preprocessing.proprocess(image) seg_2d = self.fib_store[key]['labels/0'][z_pos,] @@ -169,6 +198,50 @@ def _gen_grid_points(self, h: int, w: int) -> np.ndarray: xx, yy = np.meshgrid(xs, ys) return np.stack([xx.ravel(), yy.ravel()], axis=1) # (G,2) as (x,y) + def _sample_points_in_mask( + self, + comp: np.ndarray, + grid_points: np.ndarray, + jitter_px: float = 0.0, + shape: tuple[int, int] = None, + ) -> np.ndarray: + """ + Pick clicks from a regular grid, restricted to those that land inside `comp`. + Optionally jitter each kept grid point by up to ±jitter_px. + Returns float32 array [K,2] in (x,y). + """ + + if comp.sum() == 0: + return np.zeros((0, 2), dtype=np.float32) + + h, w = shape + + # round grid coords to nearest pixel for mask lookup + gx = np.clip(np.rint(grid_points[:, 0]).astype(int), 0, w - 1) + gy = np.clip(np.rint(grid_points[:, 1]).astype(int), 0, h - 1) + + inside = comp[gy, gx] > 0 + cand = grid_points[inside] # (M,2) subset of the grid + + if cand.shape[0] == 0: + return np.zeros((0, 2), dtype=np.float32) + + k = int(self._rng.randint(self.k_min, self.k_max + 1)) + k = min(k, cand.shape[0]) + + idx = self._rng.choice(cand.shape[0], size=k, replace=False) + pts = cand[idx].astype(np.float32) + + # optional jitter to avoid perfectly regular patterns + if jitter_px > 0: + jitter = self._rng.uniform(-jitter_px, jitter_px, size=pts.shape).astype(np.float32) + pts = pts + jitter + # keep inside image bounds + pts[:, 0] = np.clip(pts[:, 0], 0, w - 1) + pts[:, 1] = np.clip(pts[:, 1], 0, h - 1) + + return pts + def _package_image_item(self, image_2d: np.ndarray, segmentation: np.ndarray): 
@@ -186,6 +259,11 @@ def _package_image_item(self, "boxes": list[4] float32 (x0,y0,x1,y1) } """ + + # Apply transforms to image and segmentation + if self.transform: + data = self.transform({'image': image_2d, 'masks': segmentation}) + image_2d, segmentation = data['image'], data['masks'] h, w = segmentation.shape min_pixels = 0 @@ -205,7 +283,8 @@ def _package_image_item(self, continue # sample clicks from this component (NOT the full instance) - pts = helper.sample_positive_points(comp, k=self.k_pos) + # pts = helper.sample_positive_points(comp, k=self.k_pos) + pts = self._sample_points_in_mask(comp, grid_points, shape=(h, w)) if pts.shape[0] == 0: continue @@ -222,11 +301,6 @@ def _package_image_item(self, labels_t = [torch.from_numpy(np.ones((1,), dtype=np.float32))] boxes_t = [torch.from_numpy(np.array([0, 0, 1, 1], dtype=np.float32))] - # Apply transforms - if self.transform: - data = self.transform({'image': image_2d, 'masks': masks_t}) - image_2d, masks_t = data['image'], data['masks'] - return { "image": image_2d, # HxWx3 uint8 "masks": masks_t, # list[H x W] float32 in {0,1} diff --git a/saber/finetune/helper.py b/saber/finetune/helper.py index e3e8e5e..5a994c2 100644 --- a/saber/finetune/helper.py +++ b/saber/finetune/helper.py @@ -2,7 +2,7 @@ import matplotlib.patches as patches import matplotlib.pyplot as plt import numpy as np -import cv2 +import cv2, torch def mask_to_box(mask: np.ndarray) -> np.ndarray | None: """xyxy box from a binary mask (H,W) in {0,1}.""" @@ -39,16 +39,103 @@ def components_for_id(seg, iid: int, min_pixels: int): comps.append(comp) return comps -def collate_autoseg(batch): - # batch: list of dicts from _package_image_item + +def collate_autoseg(batch, max_res: int = 1024): + """ + Aspect-preserving resize to fit within max_res, then pad each sample + to the batch's (H_max, W_max). Keep ragged structures (points/labels/boxes) + as lists per instance. 
+ """ + processed = [_resize_one(sample, max_res) for sample in batch] + + # Common padded size for this batch + H_max = max(s["image"].shape[0] for s in processed) + W_max = max(s["image"].shape[1] for s in processed) + + for s in processed: + h, w = s["image"].shape[:2] + + # pad image (top-left anchoring) + pad_img = np.zeros((H_max, W_max, 3), dtype=s["image"].dtype) + pad_img[:h, :w] = s["image"] + s["image"] = pad_img + + # pad masks + padded_masks = [] + for m in s["masks"]: + pm = np.zeros((H_max, W_max), dtype=np.uint8) + pm[:h, :w] = m + padded_masks.append(pm) + s["masks"] = padded_masks + + # NOTE: points/boxes coords don't change with top-left padding + + # Return exactly what your trainer expects: lists of per-sample items, + # and for ragged things, lists of per-instance tensors. + return { + "images": [s["image"] for s in processed], # list of HxWx3 uint8 (predictor handles numpy) + "masks": [torch.from_numpy(np.stack(s["masks"])) for s in processed], # list of [M,H,W] + "points": [ [torch.from_numpy(p) for p in s["points"]] for s in processed], # list of list[[Pi,2]] + "labels": [ [torch.as_tensor(l) for l in s["labels"]] for s in processed], # list of list[[Pi]] + "boxes": [ [torch.from_numpy(b) for b in s["boxes"]] for s in processed], # list of list[[4]] + } + +def _resize_one(s, max_res: int): + """ + Resize one sample with aspect preserved. Scale coords per instance. + Keeps ragged structures as lists. 
+ """ + img = s["image"] + masks = s["masks"] # list of [H,W] + points = s["points"] # list of [Pi,2] + labels = s["labels"] # list of [Pi] + boxes = s["boxes"] # list of [4] + + H, W = img.shape[:2] + r = min(max_res / max(H, W), 1.0) # cap longest side; don't upscale + newH, newW = int(round(H * r)), int(round(W * r)) + + if (newH, newW) != (H, W): + img = cv2.resize(img, (newW, newH), interpolation=cv2.INTER_LINEAR) + masks = [cv2.resize(m.astype(np.uint8), (newW, newH), interpolation=cv2.INTER_NEAREST) for m in masks] + pts = [np.asarray(p, dtype=np.float32) * r for p in points] # scale each instance + bxs = [np.asarray(b, dtype=np.float32) * r for b in boxes] + else: + pts = [np.asarray(p, dtype=np.float32) for p in points] + bxs = [np.asarray(b, dtype=np.float32) for b in boxes] + + # labels stay as-is per instance + lbls = [np.asarray(l) for l in labels] + return { - "images": [b["image"] for b in batch], # list of HxWx3 uint8 - "masks": [b["masks"] for b in batch], # list of list[H x W] - "points": [b["points"] for b in batch], # list of list[#p x 2] - "labels": [b["labels"] for b in batch], # list of list[#p] - "boxes": [b["boxes"] for b in batch], # list of list[4] + "image": img, + "masks": masks, + "points": pts, + "labels": lbls, + "boxes": bxs, } +# def collate_autoseg(batch, max_res: int = 768): +# # batch: list of dicts from _package_image_item + +# output = { "images": [],"masks": [],"points": [],"labels": [],"boxes": []} +# for b in batch: +# results = _resize_inputs(b, max_res) +# for k, v in results.items(): +# output[k].append(v) + +# return output + +# def _resize_inputs(b, max_res): +# h, w = b["image"].shape[:2] +# if h > max_res or w > max_res: +# b["image"] = cv2.resize(b["image"], (max_res, max_res)) +# b["masks"] = [cv2.resize(m, (max_res, max_res), interpolation=cv2.INTER_NEAREST) for m in b["masks"]] +# b["points"] = [p * max_res / h for p in b["points"]] +# b["labels"] = [l * max_res / h for l in b["labels"]] +# b["boxes"] = [b * 
max_res / h for b in b["boxes"]] +# return b + def _to_numpy_mask_stack(masks): """ Accepts list/tuple of tensors or np arrays shaped [H,W]; @@ -117,4 +204,23 @@ def visualize_item_with_points(image, masks, points, boxes=None, ax.set_title(title) ax.axis("off") plt.tight_layout() - plt.show() \ No newline at end of file + # plt.show() + +# def sample_points_in_mask(mask: np.ndarray, k_min: int = 1, k_max: int = 10) -> np.ndarray: +# """ +# Uniformly sample a random number of clicks in [k_min, k_max] from a binary component mask. +# Always inside the mask; capped by the number of foreground pixels. +# Returns float32 array of shape [K, 2] in (x, y). +# """ +# ys, xs = np.nonzero(mask) # foreground coordinates +# n = xs.size +# if n == 0: +# return np.zeros((0, 2), dtype=np.float32) + +# k = int(_rng.randint(k_min, k_max + 1)) # inclusive range +# k = min(k, n) # cap by available pixels + +# # sample without replacement so clicks are distinct +# idx = self._rng.choice(n, size=k, replace=False) +# pts = np.stack([xs[idx], ys[idx]], axis=1).astype(np.float32) # (x, y) +# return pts \ No newline at end of file diff --git a/saber/finetune/metrics.py b/saber/finetune/metrics.py new file mode 100644 index 0000000..dc2dbbc --- /dev/null +++ b/saber/finetune/metrics.py @@ -0,0 +1,222 @@ +import numpy as np +import torch +import torch.nn.functional as F + +# Adjust this import to your package layout +from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator + + +# --------------------- IoU / metric helpers --------------------- + +def _mask_iou(a, b, eps=1e-6): + """ + a: [Na,H,W] {0,1}; b: [Nb,H,W] {0,1} + returns IoU matrix [Na,Nb] + """ + if a.numel() == 0 or b.numel() == 0: + return torch.zeros((a.shape[0], b.shape[0]), device=a.device) + a = a.float() + b = b.float() + inter = torch.einsum("nhw,mhw->nm", a, b) + ua = a.sum(dim=(1,2)).unsqueeze(1) + ub = b.sum(dim=(1,2)).unsqueeze(0) + union = ua + ub - inter + eps + return inter / union + + +def 
_abiou(proposals, gts): + """ + Average Best IoU (higher is better). + proposals: [Np,H,W] {0,1}, gts: [Ng,H,W] {0,1} + """ + if gts.numel() == 0 and proposals.numel() == 0: + return torch.tensor(1.0, device=proposals.device if proposals.is_cuda else gts.device) + if gts.numel() == 0: + return torch.tensor(0.0, device=proposals.device) + if proposals.numel() == 0: + return torch.tensor(0.0, device=gts.device) + iou = _mask_iou(gts, proposals) # [Ng,Np] + best = iou.max(dim=1).values # [Ng] + return best.mean() + + +def _ap_at_threshold(proposals, scores, gts, thr=0.5): + """ + Greedy one-to-one matching by descending score at a single IoU threshold. + proposals: [Np,H,W] {0,1} + scores: [Np] (higher is better) + gts: [Ng,H,W] {0,1} + """ + # Degenerate cases + if gts.numel() == 0 and proposals.numel() == 0: + return torch.tensor(1.0, device=scores.device) + if proposals.numel() == 0: + return torch.tensor(0.0, device=scores.device) + + order = scores.argsort(descending=True) + props = proposals[order] + + matched_gt = torch.zeros((gts.shape[0],), dtype=torch.bool, device=gts.device) + tp, fp = [], [] + for i in range(props.shape[0]): + if gts.numel() == 0: + fp.append(1); tp.append(0); continue + ious = _mask_iou(props[i:i+1], gts)[0] # [Ng] + j = torch.argmax(ious) + if ious[j] >= thr and not matched_gt[j]: + matched_gt[j] = True + tp.append(1); fp.append(0) + else: + tp.append(0); fp.append(1) + + tp = torch.tensor(tp, device=scores.device).cumsum(0) + fp = torch.tensor(fp, device=scores.device).cumsum(0) + precision = tp / (tp + fp).clamp(min=1) + recall = tp / max(gts.shape[0], 1) + + # Precision envelope + trapezoidal integral over recall + mrec, idx = torch.sort(recall) + mpre = precision[idx] + for k in range(mpre.shape[0] - 2, -1, -1): + mpre[k] = torch.maximum(mpre[k], mpre[k+1]) + return torch.trapz(mpre, mrec) + + +def _map(proposals, scores, gts, thresholds=(0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95)): + """ + Mean AP across IoU 
thresholds (COCO-style, class-agnostic). + """ + aps = [ _ap_at_threshold(proposals, scores, gts, thr=t) for t in thresholds ] + return torch.stack(aps).mean() + + +# --------------------- Main evaluator --------------------- + +@torch.no_grad() +def automask_metrics( + sam2_model_or_predictor, + images, + gt_masks_per_image, + *, + amg_kwargs=None, + top_k=None, # optionally cap #proposals/image after AMG (for speed) + compute_map=True, # if False, only ABIoU and AP@0.5 + map_thresholds=(0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95), + device=None, +): + """ + Run SAM2AutomaticMaskGenerator on each image and compute ABIoU, AP@0.5 (and mAP if requested). + + Args + ---- + sam2_model_or_predictor : the fine-tuned SAM2 object. If you passed a predictor wrapper, + we'll try to hand its `.model` to SAM2AutomaticMaskGenerator. + images : list of images; each can be: + - HxWx3 uint8 NumPy (preferred for AMG), or + - torch.Tensor (H,W) or (H,W,3) float in [0,1] or [0,255] + gt_masks_per_image : list[list[H x W]]; elements can be NumPy or torch; non-zero = foreground + amg_kwargs : dict of params forwarded to SAM2AutomaticMaskGenerator(...) + e.g. {'points_per_side': 16, 'points_per_batch': 64, 'pred_iou_thresh': 0.7, ...} + top_k : optional int to keep only top-K proposals per image by score after AMG + compute_map : whether to compute mAP over multiple IoU thresholds + map_thresholds : tuple of IoU thresholds + device : torch device for metrics (defaults to 'cuda' if available) + + Returns + ------- + summary : dict with aggregated metrics: + { + 'ABIoU': float, + 'AP50': float, + 'mAP': float or None, + 'num_images': int, + 'per_image': [ {'ABIoU':..., 'AP50':..., 'mAP':... or None, 'num_props': int, 'num_gt': int}, ... 
] + } + """ + # local alias to avoid confusion with user's np + _amg_kwargs = dict( + points_per_side=32, + points_per_batch=64, + pred_iou_thresh=0.7, + stability_score_thresh=0.92, + stability_score_offset=0.7, + crop_n_layers=1, + crop_n_points_downscale_factor=2, + box_nms_thresh=0.7, + use_m2m=True, + multimask_output=True, + ) + if amg_kwargs: + _amg_kwargs.update(amg_kwargs) + + # Figure out what to pass into the generator: underlying model vs predictor + model_for_amg = getattr(sam2_model_or_predictor, "model", sam2_model_or_predictor) + + mask_generator = SAM2AutomaticMaskGenerator(model=model_for_amg, **_amg_kwargs) + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + per_image = [] + abiou_vals, ap50_vals, map_vals = [], [], [] + + for img, gt_list in zip(images, gt_masks_per_image): + + # ---------- run AMG ---------- + props = mask_generator.generate(img) # list of dicts with 'segmentation' and 'predicted_iou'/'score' + + H, W, _ = img.shape + if len(props) == 0: + prop_masks = torch.zeros((0, H, W), dtype=torch.uint8, device=device) + prop_scores = torch.zeros((0,), dtype=torch.float32, device=device) + else: + # sort by predicted IoU score, keep top_k if requested + def _score(d): + return float(d.get("predicted_iou", d.get("score", 0.0))) + props = sorted(props, key=_score, reverse=True) + if top_k is not None and top_k > 0: + props = props[:top_k] + + masks_np = [(p["segmentation"] > 0).astype(np.uint8) for p in props] + scores_np = [_score(p) for p in props] + prop_masks = torch.from_numpy(np.stack(masks_np, axis=0)).to(device=device) + prop_scores = torch.tensor(scores_np, device=device, dtype=torch.float32) + + # ---------- stack GTs ---------- + if len(gt_list) == 0: + gt_masks = torch.zeros((0, H, W), dtype=torch.uint8, device=device) + else: + gts = [] + for g in gt_list: + g_np = g.detach().cpu().numpy() if isinstance(g, torch.Tensor) else g + gts.append((g_np > 0).astype(np.uint8)) + gt_masks = 
torch.from_numpy(np.stack(gts, axis=0)).to(device=device) + + # ---------- metrics ---------- + abiou = _abiou(prop_masks, gt_masks) + if compute_map: + mAP = _map(prop_masks, prop_scores, gt_masks, thresholds=map_thresholds) + else: + mAP = None + + per_image.append({ + "ABIoU": float(abiou.detach().cpu()), + "mAP": (float(mAP.detach().cpu()) if mAP is not None else None), + "num_props": int(prop_masks.shape[0]), + "num_gt": int(gt_masks.shape[0]), + }) + + abiou_vals.append(abiou) + if compute_map: + map_vals.append(mAP) + + # aggregate + ABIoU = torch.stack(abiou_vals).mean().item() if abiou_vals else 0.0 + mAP = (torch.stack(map_vals).mean().item() if compute_map and map_vals else None) + + return { + "ABIoU": ABIoU, + "mAP": mAP, + "num_images": len(per_image), + "per_image": per_image, + } diff --git a/saber/finetune/train.py b/saber/finetune/train.py index 9a0861b..f55a6c1 100644 --- a/saber/finetune/train.py +++ b/saber/finetune/train.py @@ -21,23 +21,22 @@ def finetune_sam2( """ # Determine device - device = io.get_available_devices(0) (cfg, checkpoint) = pretrained_weights.get_sam2_checkpoint(sam2_cfg) sam2_model = build_sam2(cfg, checkpoint, device='cuda', postprocess_mask=False) predictor = SAM2ImagePredictor(sam2_model) - + # Option 1 : Train the Mask Decoder and Prompt Encoder predictor.model.sam_mask_decoder.train(True) predictor.model.sam_prompt_encoder.train(True) # Load data loaders - train_loader = DataLoader(AutoMaskDataset(tomo_train, fib_train), batch_size=16, shuffle=True, - num_workers=4, pin_memory=True, collate_fn=collate_autoseg) - val_loader = DataLoader(AutoMaskDataset(tomo_val, fib_val), batch_size=16, shuffle=False, - num_workers=4, pin_memory=True, collate_fn=collate_autoseg) if (tomo_val or fib_val) else None + train_loader = DataLoader( AutoMaskDataset(tomo_train, fib_train), batch_size=16, shuffle=True, + num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) + val_loader = DataLoader( AutoMaskDataset(tomo_val, fib_val), 
batch_size=16, shuffle=False, + num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) if (tomo_val or fib_val) else None # Initialize trainer and train - trainer = SAM2FinetuneTrainer(predictor, train_loader, val_loader) + trainer = SAM2FinetuneTrainer( predictor, train_loader, val_loader ) trainer.train( num_epochs ) @click.command() @@ -58,10 +57,10 @@ def finetune(sam2_cfg: str, epochs: int, fib_train: str, fib_val: str, tomo_trai ) print(f"Using SAM2 Config: {sam2_cfg}") print(f"Using Number of Epochs: {epochs}") - print(f"Using Train Zarr: {fib_train}") - print(f"Using Val Zarr: {fib_val}") print(f"Using Train Zarr: {tomo_train}") print(f"Using Val Zarr: {tomo_val}") + print(f"Using Train Zarr: {fib_train}") + print(f"Using Val Zarr: {fib_val}") print("--------------------------------") finetune_sam2(tomo_train, fib_train, tomo_val, fib_val, sam2_cfg, epochs) \ No newline at end of file diff --git a/saber/finetune/trainer.py b/saber/finetune/trainer.py index cbad181..864307a 100644 --- a/saber/finetune/trainer.py +++ b/saber/finetune/trainer.py @@ -1,3 +1,4 @@ +from saber.finetune.metrics import automask_metrics from saber.finetune.losses import MultiMaskIoULoss from lightning import fabric from tqdm import tqdm @@ -18,21 +19,33 @@ def __init__(self, predictor, train_loader, val_loader): ] # Initialize the optimizer and dataloaders - self.num_gpus = torch.cuda.device_count() - self.fabric = fabric.Fabric(accelerator="cuda", strategy="ddp", devices=self.num_gpus) + self.num_gpus = torch.cuda.device_count() optimizer = torch.optim.AdamW(params, weight_decay=4e-5) - self.predictor.model, self.optimizer = self.fabric.setup(self.predictor.model,optimizer) - - if val_loader is None: - self.train_loader = self.fabric.setup_dataloaders(train_loader) + if self.num_gpus > 1: + self.fabric = fabric.Fabric(accelerator="cuda", strategy="ddp", devices=self.num_gpus) + self.fabric.launch() + self.predictor.model, self.optimizer = 
self.fabric.setup(self.predictor.model,optimizer) + self.autocast = self.fabric.autocast + self.use_fabric = True else: + self.optimizer = optimizer + self.use_fabric = False + self.autocast = torch.cuda.amp.autocast + self.device = next(self.predictor.model.parameters()).device + + if val_loader is None and self.use_fabric: + self.train_loader = self.fabric.setup_dataloaders(train_loader) + elif self.use_fabric and val_loader is not None: self.train_loader, self.val_loader = self.fabric.setup_dataloaders(train_loader, val_loader) + else: + self.train_loader, self.val_loader = train_loader, val_loader # Initialize the loss function self.focal_alpha = 0.25 self.focal_gamma = 2.0 self.supervise_all_iou = False self.iou_use_l1_loss = True + self.predict_multimask = True # Initialize the use_boxes flag self.use_boxes = False @@ -51,12 +64,12 @@ def _stack_image_embeddings_from_predictor(self): hr_feats: list[level] of [B, C, H', W'] """ # image_embed is a list[len=B] of [C, H', W']; stack to [B, C, H', W'] - image_embeds = torch.stack(list(self.predictor.model._features["image_embed"]), dim=0).to(self.fabric.device) + image_embeds = torch.stack(list(self.predictor._features["image_embed"]), dim=0).to(self.device) # high_res_feats is a list[level], where each level is a list[len=B] of [C, H', W'] - hr = self.predictor.model._features["high_res_feats"] + hr = self.predictor._features["high_res_feats"] B = image_embeds.shape[0] - hr_feats = [torch.stack([lvl[b] for b in range(B)], dim=0).to(self.fabric.device) for lvl in hr] + hr_feats = [torch.stack([lvl[b] for b in range(B)], dim=0).to(self.device) for lvl in hr] return image_embeds, hr_feats def forward_step(self, batch): @@ -75,20 +88,20 @@ def forward_step(self, batch): for b in range(B): for m, p, l, bx in zip(batch["masks"][b], batch["points"][b], batch["labels"][b], batch["boxes"][b]): inst_img_ix.append(b) - gt_all.append(m.to(self.fabric.device)) - pts_all.append(p.to(self.fabric.device)) - 
lbl_all.append(l.to(self.fabric.device)) - box_all.append(bx.to(self.fabric.device)) + gt_all.append(m.to(self.device)) + pts_all.append(p.to(self.device)) + lbl_all.append(l.to(self.device)) + box_all.append(bx.to(self.device)) N = len(gt_all) if N == 0: return None, None, None, None - inst_img_ix = torch.tensor(inst_img_ix, device=self.fabric.device, dtype=torch.long) + inst_img_ix = torch.tensor(inst_img_ix, device=self.device, dtype=torch.long) # 3) Pad clicks to (N,P,2) and (N,P) P = max(p.shape[0] for p in pts_all) - pts_pad = torch.zeros((N, P, 2), device=self.fabric.device, dtype=torch.float32) - lbl_pad = torch.zeros((N, P), device=self.fabric.device, dtype=torch.float32) + pts_pad = torch.zeros((N, P, 2), device=self.device, dtype=torch.float32) + lbl_pad = torch.zeros((N, P), device=self.device, dtype=torch.float32) for i, (p, l) in enumerate(zip(pts_all, lbl_all)): pts_pad[i, :p.shape[0]] = p lbl_pad[i, :l.shape[0]] = l @@ -98,7 +111,9 @@ def forward_step(self, batch): # 4) Prompt encoding mask_input, unnorm_coords, labels, _ = self.predictor._prep_prompts( - input_point=pts_pad, input_label=lbl_pad, box=boxes, mask_logits=None, normalize_coords=True + pts_pad, lbl_pad, + box=boxes, mask_logits=None, + normalize_coords=True ) sparse_embeddings, dense_embeddings = self.predictor.model.sam_prompt_encoder( points=(unnorm_coords, labels), @@ -116,7 +131,7 @@ def forward_step(self, batch): image_pe=self.predictor.model.sam_prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, - multimask_output=True, + multimask_output=self.predict_multimask, repeat_image=False, high_res_features=hr_feats, ) @@ -125,27 +140,62 @@ def forward_step(self, batch): prd_masks = self.predictor._transforms.postprocess_masks( low_res_masks, self.predictor._orig_hw[-1] ) # [N,K,H,W] logits - gt_masks = torch.stack(gt_all, dim=0) # [N,H,W] + gt_masks = torch.stack(gt_all, dim=0).float() # [N,H,W] return prd_masks, 
prd_scores, gt_masks, inst_img_ix @torch.no_grad() - def validate_step(self, batch): + def validate_step(self): """ Validate the model on the given batch. """ - if 'embeddings' in batch: - embeddings = batch['embeddings'] - else: - self.predictor.set_image_batch(batch['images']) - # Run AMG to get proposals - # proposals = self.predictor - - metrics = {} - return metrics - - def train(self, num_epochs): + self.predictor.model.eval() + + # Local accumulators (weighted by number of images in each call) + abiou_sum = torch.tensor(0.0, device=self.device) + map_sum = torch.tensor(0.0, device=self.device) + n_imgs = torch.tensor(0.0, device=self.device) + + # Each rank iterates only its shard (Fabric sets DistributedSampler for you) + for batch in self.val_loader: + # Compute metrics on THIS batch only (keeps memory small & parallel) + m = automask_metrics( + self.predictor, # predictor or predictor.model (your function supports either) + batch["images"], # list[H×W×3] or list[H×W] + batch["masks"], # list[list[H×W]] + amg_kwargs={"points_per_side": 16, "pred_iou_thresh": 0.7, "crop_n_layers": 1}, + top_k=100, + compute_map=True, + device=self.device + ) + + # Weight by number of images so we can average correctly later + num = float(m["num_images"]) + abiou_sum += torch.tensor(m["ABIoU"] * num, device=self.device) + if m["mAP"] is not None: + map_sum += torch.tensor(m["mAP"] * num, device=self.device) + n_imgs += torch.tensor(num, device=self.device) + + # Global reduction (sum across all ranks) + if self.use_fabric: + abiou_sum = self.fabric.all_reduce(abiou_sum, reduce_op="sum") + map_sum = self.fabric.all_reduce(map_sum, reduce_op="sum") + n_imgs = self.fabric.all_reduce(n_imgs, reduce_op="sum") + else: + abiou_sum = abiou_sum.sum() + map_sum = map_sum.sum() + n_imgs = n_imgs.sum() + + # Avoid divide-by-zero + denom = max(n_imgs.item(), 1.0) + return { + "ABIoU": (abiou_sum / denom).item(), + "mAP": (map_sum / denom).item(), + "num_images": int(denom), + } + + def 
train(self, num_epochs, best_metric = 'mAP'): """ Fine Tune SAM2 on the given data. """ @@ -159,37 +209,44 @@ def train(self, num_epochs): iou_use_l1_loss=self.iou_use_l1_loss ) + best_metric_value = float('-inf') self.optimizer.zero_grad() - for epoch in tqdm(range(num_epochs)): - - # Initialize the epoch loss - epoch_loss_train = 0 - epoch_loss_val = 0 - + for epoch in tqdm(range(num_epochs), desc="Training", unit="epoch"): # Train + epoch_loss_train = 0 self.predictor.model.train() + self.train_loader.dataset.resample_epoch() for batch in self.train_loader: - with self.fabric.autocast(): - out = self.forward_step(batch) - if out[0] is None: - continue - prd_masks, prd_scores, gt_masks, _ = out + out = self.forward_step(batch) + if out[0] is None: + continue + prd_masks, prd_scores, gt_masks, _ = out + with self.autocast(): losses = self.loss_fn(prd_masks, prd_scores, gt_masks) - self.fabric.backward(losses) + if self.use_fabric: + self.fabric.backward(losses['loss_total']) + else: + losses['loss_total'].backward() self.optimizer.step() self.optimizer.zero_grad() epoch_loss_train += float(losses["loss_total"].detach().cpu()) + import pdb; pdb.set_trace() + # Validate - self.predictor.model.eval() - with torch.no_grad(): - for batch in self.val_loader: - with self.fabric.autocast(): - out = self.validate_step(batch) - losses = self.loss_fn(out) + metrics = self.validate_step() + + import pdb; pdb.set_trace() # Print Only on Rank 0 if self.fabric.is_global_zero: - #Checkpoint - print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss_train/len(self.train_loader)}, Val Loss: {epoch_loss_val/len(self.val_loader)}") - torch.save(self.predictor.model.state_dict(), f"{self.save_dir}/model.pth") \ No newline at end of file + print( + f"Epoch {epoch+1}/{num_epochs} " + f"Loss={epoch_loss_train/len(self.train_loader):.5f} " + f"mAP={metrics['mAP']:.4f} - ABIoU={metrics['ABIoU']:.4f} " + ) + + if metrics[best_metric] > best_metric_value: + best_metric_value = 
metrics[best_metric] + torch.save(self.predictor.model.state_dict(), f"{self.save_dir}/best_model.pth") + print(f"Best {best_metric} saved!") \ No newline at end of file diff --git a/saber/utils/preprocessing.py b/saber/utils/preprocessing.py index 82697ff..1497cf6 100644 --- a/saber/utils/preprocessing.py +++ b/saber/utils/preprocessing.py @@ -40,7 +40,7 @@ def proprocess(image: np.ndarray, std_cutoff=3, rgb=False): """ image = contrast(image, std_cutoff=std_cutoff) image = normalize(image, rgb=rgb) - image = np.repeat(image[..., None], 3, axis=2) if rgb else image + image = image if rgb else np.repeat(image[..., None], 3, axis=2) return image def project_tomogram(vol, zSlice = None, deltaZ = None): From b7bd11ff55e4b2616056f3e0dda0c079a3dd8e28 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 17 Sep 2025 21:24:02 +0000 Subject: [PATCH 05/15] ready to start testing soon --- saber/finetune/dataset.py | 4 ++-- saber/finetune/metrics.py | 13 ++++++------- saber/finetune/train.py | 17 ++++++++++------- saber/finetune/trainer.py | 6 +----- 4 files changed, 19 insertions(+), 21 deletions(-) diff --git a/saber/finetune/dataset.py b/saber/finetune/dataset.py index 4d71eed..c90fb55 100644 --- a/saber/finetune/dataset.py +++ b/saber/finetune/dataset.py @@ -32,8 +32,8 @@ def __init__(self, # Grid and Positive Points for AutoMaskGenerator self.points_per_side = 32 self.min_area = 0.001 - self.k_min = 1 - self.k_max = 10 + self.k_min = 50 + self.k_max = 100 self.transform = transform # Check if both data types are available diff --git a/saber/finetune/metrics.py b/saber/finetune/metrics.py index dc2dbbc..13bf70c 100644 --- a/saber/finetune/metrics.py +++ b/saber/finetune/metrics.py @@ -1,9 +1,8 @@ +from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator +import torch.nn.functional as F import numpy as np import torch -import torch.nn.functional as F -# Adjust this import to your package layout -from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator # 
--------------------- IoU / metric helpers --------------------- @@ -136,15 +135,15 @@ def automask_metrics( # local alias to avoid confusion with user's np _amg_kwargs = dict( points_per_side=32, - points_per_batch=64, + points_per_batch=128, pred_iou_thresh=0.7, stability_score_thresh=0.92, stability_score_offset=0.7, crop_n_layers=1, crop_n_points_downscale_factor=2, box_nms_thresh=0.7, - use_m2m=True, - multimask_output=True, + use_m2m=False, + multimask_output=False, ) if amg_kwargs: _amg_kwargs.update(amg_kwargs) @@ -158,7 +157,7 @@ def automask_metrics( device = torch.device("cuda" if torch.cuda.is_available() else "cpu") per_image = [] - abiou_vals, ap50_vals, map_vals = [], [], [] + abiou_vals, map_vals = [], [] for img, gt_list in zip(images, gt_masks_per_image): diff --git a/saber/finetune/train.py b/saber/finetune/train.py index f55a6c1..021407c 100644 --- a/saber/finetune/train.py +++ b/saber/finetune/train.py @@ -15,7 +15,8 @@ def finetune_sam2( tomo_val: str = None, fib_val: str = None, sam2_cfg: str = 'base', - num_epochs: int = 1000): + num_epochs: int = 1000, + batch_size: int = 16): """ Finetune SAM2 on tomograms and FIBs """ @@ -30,9 +31,9 @@ def finetune_sam2( predictor.model.sam_prompt_encoder.train(True) # Load data loaders - train_loader = DataLoader( AutoMaskDataset(tomo_train, fib_train), batch_size=16, shuffle=True, + train_loader = DataLoader( AutoMaskDataset(tomo_train, fib_train), batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) - val_loader = DataLoader( AutoMaskDataset(tomo_val, fib_val), batch_size=16, shuffle=False, + val_loader = DataLoader( AutoMaskDataset(tomo_val, fib_val), batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) if (tomo_val or fib_val) else None # Initialize trainer and train @@ -41,12 +42,13 @@ def finetune_sam2( @click.command() @sam2_inputs -@click.option("--epochs", type=int, default=1000, help="Number of epochs to 
train for") @click.option("--fib-train", type=str, help="Path to train Zarr") @click.option("--fib-val", type=str, help="Path to val Zarr") @click.option("--tomo-train", type=str, help="Path to train Zarr") @click.option("--tomo-val", type=str, help="Path to val Zarr") -def finetune(sam2_cfg: str, epochs: int, fib_train: str, fib_val: str, tomo_train: str, tomo_val: str): +@click.option("--epochs", type=int, default=1000, help="Number of epochs to train for") +@click.option('--batch-size', type=int, default=16, help="Batch size") +def finetune(sam2_cfg: str, epochs: int, fib_train: str, fib_val: str, tomo_train: str, tomo_val: str, batch_size: int): """ Finetune SAM2 on 3D Volumes. Images from input tomograms and fibs are generated with slabs and slices, respectively. """ @@ -56,11 +58,12 @@ def finetune(sam2_cfg: str, epochs: int, fib_train: str, fib_val: str, tomo_trai f"Fine Tuning SAM2 on {fib_train} and {fib_val} and {tomo_train} and {tomo_val} for {epochs} epochs" ) print(f"Using SAM2 Config: {sam2_cfg}") - print(f"Using Number of Epochs: {epochs}") print(f"Using Train Zarr: {tomo_train}") print(f"Using Val Zarr: {tomo_val}") print(f"Using Train Zarr: {fib_train}") print(f"Using Val Zarr: {fib_val}") + print(f"Using Number of Epochs: {epochs}") + print(f"Using Batch Size: {batch_size}") print("--------------------------------") - finetune_sam2(tomo_train, fib_train, tomo_val, fib_val, sam2_cfg, epochs) \ No newline at end of file + finetune_sam2(tomo_train, fib_train, tomo_val, fib_val, sam2_cfg, epochs, batch_size) \ No newline at end of file diff --git a/saber/finetune/trainer.py b/saber/finetune/trainer.py index 864307a..12bd011 100644 --- a/saber/finetune/trainer.py +++ b/saber/finetune/trainer.py @@ -165,7 +165,7 @@ def validate_step(self): batch["images"], # list[H×W×3] or list[H×W] batch["masks"], # list[list[H×W]] amg_kwargs={"points_per_side": 16, "pred_iou_thresh": 0.7, "crop_n_layers": 1}, - top_k=100, + top_k=20, compute_map=True, 
device=self.device ) @@ -231,13 +231,9 @@ def train(self, num_epochs, best_metric = 'mAP'): self.optimizer.zero_grad() epoch_loss_train += float(losses["loss_total"].detach().cpu()) - import pdb; pdb.set_trace() - # Validate metrics = self.validate_step() - import pdb; pdb.set_trace() - # Print Only on Rank 0 if self.fabric.is_global_zero: print( From 2f050885bf0d4d2a030100da72963d36a0a92bba Mon Sep 17 00:00:00 2001 From: root Date: Mon, 22 Sep 2025 17:37:23 +0000 Subject: [PATCH 06/15] mask decoder is improving, now to focus on the automask generator --- pyproject.toml | 9 +- saber/finetune/dataset.py | 102 +++++++++++++++----- saber/finetune/helper.py | 61 ++++-------- saber/finetune/metrics.py | 196 ++++++++++++-------------------------- saber/finetune/train.py | 2 +- saber/finetune/trainer.py | 128 ++++++++++++++++++------- 6 files changed, 259 insertions(+), 239 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b133761..fd5b2c0 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,14 +34,15 @@ dependencies = [ "starfile", "lightning", "matplotlib", - "opencv-python", - "multiprocess", - "torchmetrics", - "scikit-learn", "ipywidgets", "umap-learn", "torch-ema", + "tensorboard", + "multiprocess", + "torchmetrics", + "scikit-learn", "copick-utils", + "opencv-python", ] [project.optional-dependencies] diff --git a/saber/finetune/dataset.py b/saber/finetune/dataset.py index c90fb55..f320e22 100644 --- a/saber/finetune/dataset.py +++ b/saber/finetune/dataset.py @@ -1,3 +1,4 @@ +from scipy.ndimage import binary_erosion, binary_dilation import saber.finetune.helper as helper from saber.utils import preprocessing from torch.utils.data import Dataset @@ -86,7 +87,7 @@ def __init__(self, # Verbose Flag self.verbose = False - def _estimate_zrange(self, key, band=(0.25, 0.75), threshold=0): + def _estimate_zrange(self, key, band=(0.3, 0.7), threshold=0): """ Returns (z_min, z_max) inclusive bounds for valid slab centers where there is some foreground in the 
labels. @@ -99,8 +100,8 @@ def _estimate_zrange(self, key, band=(0.25, 0.75), threshold=0): mask = self.tomogram_store[key]['labels/0'][min_offset:max_offset,] vals = mask.sum(axis=(1,2)) vals = np.nonzero(vals)[0] - max_val = vals.max() - self.slab_thickness // 2 + min_offset - min_val = vals.min() + self.slab_thickness // 2 + min_offset + max_val = vals.max() - self.slab_thickness + min_offset + min_val = vals.min() + self.slab_thickness + min_offset return int(min_val), int(max_val) def resample_epoch(self): @@ -198,49 +199,98 @@ def _gen_grid_points(self, h: int, w: int) -> np.ndarray: xx, yy = np.meshgrid(xs, ys) return np.stack([xx.ravel(), yy.ravel()], axis=1) # (G,2) as (x,y) + def _sample_negative_ring(self, comp, other_inst=None, ring=3, max_neg=16, shape=None): + h, w = shape + comp_b = comp.astype(np.bool_) + outer = binary_dilation(comp_b, iterations=ring) & (~comp_b) + if other_inst is not None: + # avoid putting negatives inside any other instance + outer = outer & (~other_inst.astype(np.bool_)) + ys, xs = np.where(outer) + if len(xs) == 0: + return np.zeros((0, 2), np.float32) + k = min(max_neg, len(xs)) + idx = self._rng.choice(len(xs), size=k, replace=False) + return np.stack([xs[idx].astype(np.float32), ys[idx].astype(np.float32)], axis=1) + def _sample_points_in_mask( self, comp: np.ndarray, grid_points: np.ndarray, - jitter_px: float = 0.0, - shape: tuple[int, int] = None, + shape: tuple[int, int], + jitter_px: float = 1.0, + k_cap: int = 300, + boundary_frac: float = 0.35, ) -> np.ndarray: """ - Pick clicks from a regular grid, restricted to those that land inside `comp`. - Optionally jitter each kept grid point by up to ±jitter_px. + Pick informative clicks from a dense grid: + - favor boundary points + - cap count ~sqrt(area) up to k_cap Returns float32 array [K,2] in (x,y). 
""" - - if comp.sum() == 0: - return np.zeros((0, 2), dtype=np.float32) - h, w = shape + if comp.sum() == 0: + return np.zeros((0, 2), np.float32) + + # ----- cast to boolean for morphology ----- + comp_b = comp.astype(np.bool_) # important: avoid TypeError with '^' - # round grid coords to nearest pixel for mask lookup + # grid → nearest pixel indices for inside/boundary tests gx = np.clip(np.rint(grid_points[:, 0]).astype(int), 0, w - 1) gy = np.clip(np.rint(grid_points[:, 1]).astype(int), 0, h - 1) - inside = comp[gy, gx] > 0 - cand = grid_points[inside] # (M,2) subset of the grid - + inside = comp_b[gy, gx] + cand = grid_points[inside] if cand.shape[0] == 0: - return np.zeros((0, 2), dtype=np.float32) + return np.zeros((0, 2), np.float32) + + # ----- boundary mask (inner ring) ----- + eroded = binary_erosion(comp_b, iterations=2) + boundary_b = np.logical_and(comp_b, np.logical_not(eroded)) # same as comp ^ eroded on booleans - k = int(self._rng.randint(self.k_min, self.k_max + 1)) - k = min(k, cand.shape[0]) + on_b = boundary_b[gy, gx] & inside + cand_b = grid_points[on_b] + cand_i = grid_points[inside & (~on_b)] + + # target k ~ c * area but capped + area = float(comp_b.sum()) + k_target = int( + min( k_cap, max(24, area * 0.12) ) + ) + + kb = int(boundary_frac * k_target) + ki = k_target - kb + + rng = self._rng + take_b = min(kb, len(cand_b)) + take_i = min(ki, len(cand_i)) + + # if boundary is too small, backfill from interior + if take_b + take_i == 0: + return np.zeros((0, 2), np.float32) + + if take_b: + idx_b = rng.choice(len(cand_b), size=take_b, replace=False) + pts_b = cand_b[idx_b] + else: + pts_b = np.zeros((0, 2), np.float32) + + if take_i: + idx_i = rng.choice(len(cand_i), size=take_i, replace=False) + pts_i = cand_i[idx_i] + else: + pts_i = np.zeros((0, 2), np.float32) - idx = self._rng.choice(cand.shape[0], size=k, replace=False) - pts = cand[idx].astype(np.float32) + pts = np.concatenate([pts_b, pts_i], axis=0).astype(np.float32) - # optional 
jitter to avoid perfectly regular patterns - if jitter_px > 0: - jitter = self._rng.uniform(-jitter_px, jitter_px, size=pts.shape).astype(np.float32) - pts = pts + jitter - # keep inside image bounds + # jitter a touch to avoid perfect grid regularity + if jitter_px > 0 and pts.shape[0] > 0: + jitter = rng.uniform(-jitter_px, jitter_px, size=pts.shape).astype(np.float32) + pts += jitter pts[:, 0] = np.clip(pts[:, 0], 0, w - 1) pts[:, 1] = np.clip(pts[:, 1], 0, h - 1) - return pts + return pts def _package_image_item(self, image_2d: np.ndarray, diff --git a/saber/finetune/helper.py b/saber/finetune/helper.py index 5a994c2..de70051 100644 --- a/saber/finetune/helper.py +++ b/saber/finetune/helper.py @@ -1,8 +1,8 @@ from saber.visualization.classifier import get_colors, add_masks import matplotlib.patches as patches import matplotlib.pyplot as plt +import cv2, torch, csv, os import numpy as np -import cv2, torch def mask_to_box(mask: np.ndarray) -> np.ndarray | None: """xyxy box from a binary mask (H,W) in {0,1}.""" @@ -115,27 +115,6 @@ def _resize_one(s, max_res: int): "boxes": bxs, } -# def collate_autoseg(batch, max_res: int = 768): -# # batch: list of dicts from _package_image_item - -# output = { "images": [],"masks": [],"points": [],"labels": [],"boxes": []} -# for b in batch: -# results = _resize_inputs(b, max_res) -# for k, v in results.items(): -# output[k].append(v) - -# return output - -# def _resize_inputs(b, max_res): -# h, w = b["image"].shape[:2] -# if h > max_res or w > max_res: -# b["image"] = cv2.resize(b["image"], (max_res, max_res)) -# b["masks"] = [cv2.resize(m, (max_res, max_res), interpolation=cv2.INTER_NEAREST) for m in b["masks"]] -# b["points"] = [p * max_res / h for p in b["points"]] -# b["labels"] = [l * max_res / h for l in b["labels"]] -# b["boxes"] = [b * max_res / h for b in b["boxes"]] -# return b - def _to_numpy_mask_stack(masks): """ Accepts list/tuple of tensors or np arrays shaped [H,W]; @@ -204,23 +183,21 @@ def 
visualize_item_with_points(image, masks, points, boxes=None, ax.set_title(title) ax.axis("off") plt.tight_layout() - # plt.show() - -# def sample_points_in_mask(mask: np.ndarray, k_min: int = 1, k_max: int = 10) -> np.ndarray: -# """ -# Uniformly sample a random number of clicks in [k_min, k_max] from a binary component mask. -# Always inside the mask; capped by the number of foreground pixels. -# Returns float32 array of shape [K, 2] in (x, y). -# """ -# ys, xs = np.nonzero(mask) # foreground coordinates -# n = xs.size -# if n == 0: -# return np.zeros((0, 2), dtype=np.float32) - -# k = int(_rng.randint(k_min, k_max + 1)) # inclusive range -# k = min(k, n) # cap by available pixels - -# # sample without replacement so clicks are distinct -# idx = self._rng.choice(n, size=k, replace=False) -# pts = np.stack([xs[idx], ys[idx]], axis=1).astype(np.float32) # (x, y) -# return pts \ No newline at end of file + +def save_training_log(results, outdir="results"): + + # CSV (epoch-aligned, pad with blanks if needed) + path = os.path.join(outdir, "metrics.csv") + is_new = not os.path.exists(path) + + with open(path, "a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["epoch", "lr", "train_loss", "val_loss", "ABIoU"]) + if is_new: + writer.writeheader() + writer.writerow({ + "epoch": int(results['epoch']), + "lr": f"{results['lr']:.1e}", + "train_loss": float(results['train']['loss']), + "val_loss": float(results['loss']), + "ABIoU": float(results['ABIoU']), + }) \ No newline at end of file diff --git a/saber/finetune/metrics.py b/saber/finetune/metrics.py index 13bf70c..df7b6d3 100644 --- a/saber/finetune/metrics.py +++ b/saber/finetune/metrics.py @@ -1,95 +1,46 @@ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator -import torch.nn.functional as F +from contextlib import nullcontext import numpy as np import torch - - -# --------------------- IoU / metric helpers --------------------- +# --------------------- IoU / ABIoU --------------------- def 
_mask_iou(a, b, eps=1e-6): """ - a: [Na,H,W] {0,1}; b: [Nb,H,W] {0,1} - returns IoU matrix [Na,Nb] + a: [Na,H,W] {0,1}; b: [Nb,H,W] {0,1} -> IoU [Na,Nb] """ if a.numel() == 0 or b.numel() == 0: - return torch.zeros((a.shape[0], b.shape[0]), device=a.device) + dev = a.device if a.numel() > 0 else (b.device if b.numel() > 0 else torch.device("cpu")) + Na = a.shape[0] if a.numel() > 0 else 0 + Nb = b.shape[0] if b.numel() > 0 else 0 + return torch.zeros((Na, Nb), device=dev, dtype=torch.float32) + a = a.float() b = b.float() - inter = torch.einsum("nhw,mhw->nm", a, b) - ua = a.sum(dim=(1,2)).unsqueeze(1) - ub = b.sum(dim=(1,2)).unsqueeze(0) - union = ua + ub - inter + eps - return inter / union + inter = torch.einsum("nhw,mhw->nm", a, b) # [Na,Nb] + ua = a.sum(dim=(1,2))[:, None] # [Na,1] + ub = b.sum(dim=(1,2))[None, :] # [1,Nb] + + union = ua + ub - inter + eps # [Na,Nb] + return inter / union def _abiou(proposals, gts): """ - Average Best IoU (higher is better). + Average Best IoU (coverage metric). proposals: [Np,H,W] {0,1}, gts: [Ng,H,W] {0,1} """ if gts.numel() == 0 and proposals.numel() == 0: - return torch.tensor(1.0, device=proposals.device if proposals.is_cuda else gts.device) - if gts.numel() == 0: - return torch.tensor(0.0, device=proposals.device) - if proposals.numel() == 0: - return torch.tensor(0.0, device=gts.device) - iou = _mask_iou(gts, proposals) # [Ng,Np] - best = iou.max(dim=1).values # [Ng] + dev = proposals.device if proposals.numel() > 0 else gts.device + return torch.tensor(1.0, device=dev, dtype=torch.float32) + if gts.numel() == 0 or proposals.numel() == 0: + dev = proposals.device if proposals.numel() > 0 else gts.device + return torch.tensor(0.0, device=dev, dtype=torch.float32) + iou = _mask_iou(gts, proposals) # [Ng,Np] + best = iou.max(dim=1).values # [Ng] return best.mean() - -def _ap_at_threshold(proposals, scores, gts, thr=0.5): - """ - Greedy one-to-one matching by descending score at a single IoU threshold. 
- proposals: [Np,H,W] {0,1} - scores: [Np] (higher is better) - gts: [Ng,H,W] {0,1} - """ - # Degenerate cases - if gts.numel() == 0 and proposals.numel() == 0: - return torch.tensor(1.0, device=scores.device) - if proposals.numel() == 0: - return torch.tensor(0.0, device=scores.device) - - order = scores.argsort(descending=True) - props = proposals[order] - - matched_gt = torch.zeros((gts.shape[0],), dtype=torch.bool, device=gts.device) - tp, fp = [], [] - for i in range(props.shape[0]): - if gts.numel() == 0: - fp.append(1); tp.append(0); continue - ious = _mask_iou(props[i:i+1], gts)[0] # [Ng] - j = torch.argmax(ious) - if ious[j] >= thr and not matched_gt[j]: - matched_gt[j] = True - tp.append(1); fp.append(0) - else: - tp.append(0); fp.append(1) - - tp = torch.tensor(tp, device=scores.device).cumsum(0) - fp = torch.tensor(fp, device=scores.device).cumsum(0) - precision = tp / (tp + fp).clamp(min=1) - recall = tp / max(gts.shape[0], 1) - - # Precision envelope + trapezoidal integral over recall - mrec, idx = torch.sort(recall) - mpre = precision[idx] - for k in range(mpre.shape[0] - 2, -1, -1): - mpre[k] = torch.maximum(mpre[k], mpre[k+1]) - return torch.trapz(mpre, mrec) - - -def _map(proposals, scores, gts, thresholds=(0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95)): - """ - Mean AP across IoU thresholds (COCO-style, class-agnostic). 
- """ - aps = [ _ap_at_threshold(proposals, scores, gts, thr=t) for t in thresholds ] - return torch.stack(aps).mean() - - -# --------------------- Main evaluator --------------------- +# --------------------- ABIoU evaluator --------------------- @torch.no_grad() def automask_metrics( @@ -98,42 +49,25 @@ def automask_metrics( gt_masks_per_image, *, amg_kwargs=None, - top_k=None, # optionally cap #proposals/image after AMG (for speed) - compute_map=True, # if False, only ABIoU and AP@0.5 - map_thresholds=(0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95), + top_k=20, device=None, + autocast_ctx=None, # callable -> context manager (e.g., trainer.autocast); if None, no autocast ): """ - Run SAM2AutomaticMaskGenerator on each image and compute ABIoU, AP@0.5 (and mAP if requested). + Run SAM2AutomaticMaskGenerator and compute ABIoU only. Args ---- - sam2_model_or_predictor : the fine-tuned SAM2 object. If you passed a predictor wrapper, - we'll try to hand its `.model` to SAM2AutomaticMaskGenerator. - images : list of images; each can be: - - HxWx3 uint8 NumPy (preferred for AMG), or - - torch.Tensor (H,W) or (H,W,3) float in [0,1] or [0,255] - gt_masks_per_image : list[list[H x W]]; elements can be NumPy or torch; non-zero = foreground - amg_kwargs : dict of params forwarded to SAM2AutomaticMaskGenerator(...) - e.g. {'points_per_side': 16, 'points_per_batch': 64, 'pred_iou_thresh': 0.7, ...} - top_k : optional int to keep only top-K proposals per image by score after AMG - compute_map : whether to compute mAP over multiple IoU thresholds - map_thresholds : tuple of IoU thresholds - device : torch device for metrics (defaults to 'cuda' if available) - - Returns - ------- - summary : dict with aggregated metrics: - { - 'ABIoU': float, - 'AP50': float, - 'mAP': float or None, - 'num_images': int, - 'per_image': [ {'ABIoU':..., 'AP50':..., 'mAP':... or None, 'num_props': int, 'num_gt': int}, ... 
] - } + sam2_model_or_predictor : SAM2 model or predictor (we’ll use .model if present) + images : list of images: HxW or HxWx3; uint8 NumPy preferred, but float is fine + gt_masks_per_image : list[list[H x W]]; elements may be NumPy or torch; non-zero = foreground + amg_kwargs : dict forwarded to SAM2AutomaticMaskGenerator + top_k : keep only top-K proposals per image by AMG score (optional) + device : torch device for tensors (defaults to 'cuda' if available) + autocast_ctx : callable returning a context manager for mixed precision during AMG forward only """ - # local alias to avoid confusion with user's np - _amg_kwargs = dict( + # Defaults for AMG (tweak as you like) + _amg = dict( points_per_side=32, points_per_batch=128, pred_iou_thresh=0.7, @@ -143,79 +77,73 @@ def automask_metrics( crop_n_points_downscale_factor=2, box_nms_thresh=0.7, use_m2m=False, - multimask_output=False, + multimask_output=True, ) if amg_kwargs: - _amg_kwargs.update(amg_kwargs) - - # Figure out what to pass into the generator: underlying model vs predictor + _amg.update(amg_kwargs) model_for_amg = getattr(sam2_model_or_predictor, "model", sam2_model_or_predictor) + mask_generator = SAM2AutomaticMaskGenerator(model=model_for_amg, **_amg) - mask_generator = SAM2AutomaticMaskGenerator(model=model_for_amg, **_amg_kwargs) - + # Determine Device if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - per_image = [] - abiou_vals, map_vals = [], [] + # AutoCast + ac = autocast_ctx if autocast_ctx is not None else (lambda: nullcontext()) + # Per-image loop + per_image, abiou_vals = [], [] for img, gt_list in zip(images, gt_masks_per_image): - - # ---------- run AMG ---------- - props = mask_generator.generate(img) # list of dicts with 'segmentation' and 'predicted_iou'/'score' + # -------- ensure AMG-friendly uint8 numpy image -------- + if isinstance(img, torch.Tensor): + img_np = img.detach().cpu().numpy() + else: + img_np = img + H, W = img_np.shape[:2] + + 
# -------- AMG forward under autocast (fast path) -------- + with ac(): + props = mask_generator.generate(img_np) # list of dicts - H, W, _ = img.shape if len(props) == 0: prop_masks = torch.zeros((0, H, W), dtype=torch.uint8, device=device) - prop_scores = torch.zeros((0,), dtype=torch.float32, device=device) else: - # sort by predicted IoU score, keep top_k if requested + # Sort by predicted_iou/score and optionally keep top_k def _score(d): return float(d.get("predicted_iou", d.get("score", 0.0))) - props = sorted(props, key=_score, reverse=True) + props.sort(key=_score, reverse=True) if top_k is not None and top_k > 0: props = props[:top_k] masks_np = [(p["segmentation"] > 0).astype(np.uint8) for p in props] - scores_np = [_score(p) for p in props] prop_masks = torch.from_numpy(np.stack(masks_np, axis=0)).to(device=device) - prop_scores = torch.tensor(scores_np, device=device, dtype=torch.float32) - # ---------- stack GTs ---------- + # -------- stack GTs -------- if len(gt_list) == 0: gt_masks = torch.zeros((0, H, W), dtype=torch.uint8, device=device) else: gts = [] for g in gt_list: - g_np = g.detach().cpu().numpy() if isinstance(g, torch.Tensor) else g + if isinstance(g, torch.Tensor): + g_np = g.detach().cpu().numpy() + else: + g_np = g gts.append((g_np > 0).astype(np.uint8)) gt_masks = torch.from_numpy(np.stack(gts, axis=0)).to(device=device) - # ---------- metrics ---------- + # -------- ABIoU in float32 (stable) -------- abiou = _abiou(prop_masks, gt_masks) - if compute_map: - mAP = _map(prop_masks, prop_scores, gt_masks, thresholds=map_thresholds) - else: - mAP = None - per_image.append({ "ABIoU": float(abiou.detach().cpu()), - "mAP": (float(mAP.detach().cpu()) if mAP is not None else None), "num_props": int(prop_masks.shape[0]), "num_gt": int(gt_masks.shape[0]), }) - abiou_vals.append(abiou) - if compute_map: - map_vals.append(mAP) - # aggregate ABIoU = torch.stack(abiou_vals).mean().item() if abiou_vals else 0.0 - mAP = 
(torch.stack(map_vals).mean().item() if compute_map and map_vals else None) - return { "ABIoU": ABIoU, - "mAP": mAP, + 'ABIoU_per_image': abiou_vals, "num_images": len(per_image), "per_image": per_image, } diff --git a/saber/finetune/train.py b/saber/finetune/train.py index 021407c..36dd3e4 100644 --- a/saber/finetune/train.py +++ b/saber/finetune/train.py @@ -34,7 +34,7 @@ def finetune_sam2( train_loader = DataLoader( AutoMaskDataset(tomo_train, fib_train), batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) val_loader = DataLoader( AutoMaskDataset(tomo_val, fib_val), batch_size=batch_size, shuffle=False, - num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) if (tomo_val or fib_val) else None + num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) if (tomo_val or fib_val) else train_loader # Initialize trainer and train trainer = SAM2FinetuneTrainer( predictor, train_loader, val_loader ) diff --git a/saber/finetune/trainer.py b/saber/finetune/trainer.py index 12bd011..ac954b8 100644 --- a/saber/finetune/trainer.py +++ b/saber/finetune/trainer.py @@ -1,3 +1,5 @@ +from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR +from saber.finetune.helper import save_training_log from saber.finetune.metrics import automask_metrics from saber.finetune.losses import MultiMaskIoULoss from lightning import fabric @@ -20,22 +22,24 @@ def __init__(self, predictor, train_loader, val_loader): # Initialize the optimizer and dataloaders self.num_gpus = torch.cuda.device_count() - optimizer = torch.optim.AdamW(params, weight_decay=4e-5) + optimizer = torch.optim.AdamW(params, weight_decay=1e-5) if self.num_gpus > 1: self.fabric = fabric.Fabric(accelerator="cuda", strategy="ddp", devices=self.num_gpus) self.fabric.launch() self.predictor.model, self.optimizer = self.fabric.setup(self.predictor.model,optimizer) + self.predictor.model.mark_forward_method('forward_image') self.autocast = 
self.fabric.autocast self.use_fabric = True else: self.optimizer = optimizer self.use_fabric = False - self.autocast = torch.cuda.amp.autocast + def _autocast(): + return torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()) + self.autocast = _autocast self.device = next(self.predictor.model.parameters()).device - if val_loader is None and self.use_fabric: - self.train_loader = self.fabric.setup_dataloaders(train_loader) - elif self.use_fabric and val_loader is not None: + # Setup dataloaders + if self.use_fabric: self.train_loader, self.val_loader = self.fabric.setup_dataloaders(train_loader, val_loader) else: self.train_loader, self.val_loader = train_loader, val_loader @@ -43,8 +47,8 @@ def __init__(self, predictor, train_loader, val_loader): # Initialize the loss function self.focal_alpha = 0.25 self.focal_gamma = 2.0 - self.supervise_all_iou = False - self.iou_use_l1_loss = True + self.supervise_all_iou = True + self.iou_use_l1_loss = False self.predict_multimask = True # Initialize the use_boxes flag @@ -54,6 +58,11 @@ def __init__(self, predictor, train_loader, val_loader): self.save_dir = 'results' os.makedirs(self.save_dir, exist_ok=True) + @property + def is_global_zero(self): + # True on single-process runs; Fabric guards inside when present + return (not self.use_fabric) or (self.use_fabric is not None and self.fabric.is_global_zero) + @torch.no_grad() def _stack_image_embeddings_from_predictor(self): """ @@ -100,9 +109,9 @@ def forward_step(self, batch): # 3) Pad clicks to (N,P,2) and (N,P) P = max(p.shape[0] for p in pts_all) - pts_pad = torch.zeros((N, P, 2), device=self.device, dtype=torch.float32) - lbl_pad = torch.zeros((N, P), device=self.device, dtype=torch.float32) - for i, (p, l) in enumerate(zip(pts_all, lbl_all)): + pts_pad = torch.zeros((N, P, 2), device=self.device) + lbl_pad = torch.full((N, P), -1.0, device=self.device) # <- ignore + for i,(p,l) in enumerate(zip(pts_all, lbl_all)): pts_pad[i, :p.shape[0]] = p lbl_pad[i, 
:l.shape[0]] = l @@ -154,95 +163,150 @@ def validate_step(self): # Local accumulators (weighted by number of images in each call) abiou_sum = torch.tensor(0.0, device=self.device) - map_sum = torch.tensor(0.0, device=self.device) + loss_sum = torch.tensor(0.0, device=self.device) n_imgs = torch.tensor(0.0, device=self.device) + n_inst = torch.tensor(0.0, device=self.device) # Each rank iterates only its shard (Fabric sets DistributedSampler for you) for batch in self.val_loader: + + # Compute Loss on decoder outputs + out = self.forward_step(batch) + if out[0] is None: + continue # no instances in this batch + prd_masks, prd_scores, gt_masks = out[:3] + batch_n = torch.tensor(float(gt_masks.shape[0]), device=self.device) + + with self.autocast(): + losses = self.loss_fn(prd_masks, prd_scores, gt_masks) + # convert to sum over instances + loss_sum += float(losses["loss_total"].detach().cpu()) * batch_n + n_inst += batch_n + # Compute metrics on THIS batch only (keeps memory small & parallel) m = automask_metrics( self.predictor, # predictor or predictor.model (your function supports either) batch["images"], # list[H×W×3] or list[H×W] batch["masks"], # list[list[H×W]] - amg_kwargs={"points_per_side": 16, "pred_iou_thresh": 0.7, "crop_n_layers": 1}, top_k=20, - compute_map=True, - device=self.device + device=self.device, + autocast_ctx=self.autocast, ) # Weight by number of images so we can average correctly later num = float(m["num_images"]) abiou_sum += torch.tensor(m["ABIoU"] * num, device=self.device) - if m["mAP"] is not None: - map_sum += torch.tensor(m["mAP"] * num, device=self.device) n_imgs += torch.tensor(num, device=self.device) # Global reduction (sum across all ranks) if self.use_fabric: + loss_sum = self.fabric.all_reduce(loss_sum, reduce_op="sum") abiou_sum = self.fabric.all_reduce(abiou_sum, reduce_op="sum") - map_sum = self.fabric.all_reduce(map_sum, reduce_op="sum") n_imgs = self.fabric.all_reduce(n_imgs, reduce_op="sum") + n_inst = 
self.fabric.all_reduce(n_inst, reduce_op="sum") else: abiou_sum = abiou_sum.sum() - map_sum = map_sum.sum() n_imgs = n_imgs.sum() # Avoid divide-by-zero denom = max(n_imgs.item(), 1.0) + loss_denom = max(n_inst.item(), 1.0) return { + "loss": (loss_sum / loss_denom).item(), "ABIoU": (abiou_sum / denom).item(), - "mAP": (map_sum / denom).item(), "num_images": int(denom), } - def train(self, num_epochs, best_metric = 'mAP'): + def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 10): """ Fine Tune SAM2 on the given data. """ # Initialize the loss function self.loss_fn = MultiMaskIoULoss( - weight_dict={"loss_mask": 1.0, "loss_dice": 1.0, "loss_iou": 0.05}, + weight_dict={"loss_mask": 1.0, "loss_dice": 1.0, "loss_iou": 0.15}, focal_alpha=self.focal_alpha, focal_gamma=self.focal_gamma, supervise_all_iou=self.supervise_all_iou, iou_use_l1_loss=self.iou_use_l1_loss ) + # Cosine scheduler w/Warmup ---- + warmup_epochs = max(int(0.05 * num_epochs), 1) + self.warmup_sched = LinearLR(self.optimizer, start_factor=1e-3, total_iters=warmup_epochs) + self.cosine_sched = CosineAnnealingLR(self.optimizer, T_max=(num_epochs - warmup_epochs), eta_min=1e-6) + self.scheduler = SequentialLR(self.optimizer, [self.warmup_sched, self.cosine_sched], milestones=[warmup_epochs]) + + # Progress bar only on rank 0 + if self.is_global_zero: + pbar = tqdm(total=num_epochs, desc='Fine Tuning SAM2', unit='epoch', + leave=True, dynamic_ncols=True) + else: + pbar = None + best_metric_value = float('-inf') self.optimizer.zero_grad() - for epoch in tqdm(range(num_epochs), desc="Training", unit="epoch"): + for epoch in range(num_epochs): # Train epoch_loss_train = 0 self.predictor.model.train() - self.train_loader.dataset.resample_epoch() + if (epoch+1) % resample_frequency == 0: + self.train_loader.dataset.resample_epoch() for batch in self.train_loader: out = self.forward_step(batch) if out[0] is None: continue - prd_masks, prd_scores, gt_masks, _ = out + prd_masks, prd_scores, 
gt_masks = out[:3] with self.autocast(): losses = self.loss_fn(prd_masks, prd_scores, gt_masks) if self.use_fabric: self.fabric.backward(losses['loss_total']) else: losses['loss_total'].backward() + + # (optional) gradient clip: + if self.use_fabric: + # norm-based clipping (L2) on all params in the optimizer + self.fabric.clip_gradients( + self.predictor.model, + self.optimizer, + max_norm=1.0, + norm_type=2.0 + ) + else: + torch.nn.utils.clip_grad_norm_( + self.predictor.model.parameters(), + 1.0, norm_type=2.0) + self.optimizer.step() self.optimizer.zero_grad() epoch_loss_train += float(losses["loss_total"].detach().cpu()) + # Learning Rate Scheduler + self.scheduler.step() + # Validate metrics = self.validate_step() # Print Only on Rank 0 - if self.fabric.is_global_zero: - print( - f"Epoch {epoch+1}/{num_epochs} " - f"Loss={epoch_loss_train/len(self.train_loader):.5f} " - f"mAP={metrics['mAP']:.4f} - ABIoU={metrics['ABIoU']:.4f} " - ) - + if self.is_global_zero: + # Print Metrics + metrics['train'] = {'loss': epoch_loss_train/len(self.train_loader)} + pbar.set_postfix({ + "train_loss": f"{metrics['train']['loss']:.4f}", + "val_loss": f"{metrics['loss']:.4f}", + "ABIoU": f"{metrics['ABIoU']:.4f}", + }) + pbar.update(1) + + # Save Training Log + metrics['epoch'] = epoch + metrics['lr'] = self.scheduler.get_last_lr()[0] + save_training_log(metrics, self.save_dir) + + # Save Model if best metric is achieved if metrics[best_metric] > best_metric_value: best_metric_value = metrics[best_metric] - torch.save(self.predictor.model.state_dict(), f"{self.save_dir}/best_model.pth") + ckpt = {"model": self.predictor.model.state_dict()} + torch.save(ckpt, f"{self.save_dir}/best_model.pth") print(f"Best {best_metric} saved!") \ No newline at end of file From a4684e05dda7d4fc294ea7105d1f83cbc0a32ce2 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 23 Sep 2025 16:14:07 +0000 Subject: [PATCH 07/15] working fine tuning script --- saber/finetune/helper.py | 103 
+++++++++++++++++++++++++++++++++- saber/finetune/losses.py | 88 ++++++++++++++--------------- saber/finetune/metrics.py | 3 + saber/finetune/trainer.py | 113 ++++++++++++++++++++++++++++++++++++-- 4 files changed, 252 insertions(+), 55 deletions(-) diff --git a/saber/finetune/helper.py b/saber/finetune/helper.py index de70051..f57634d 100644 --- a/saber/finetune/helper.py +++ b/saber/finetune/helper.py @@ -200,4 +200,105 @@ def save_training_log(results, outdir="results"): "train_loss": float(results['train']['loss']), "val_loss": float(results['loss']), "ABIoU": float(results['ABIoU']), - }) \ No newline at end of file + }) + +######################################################################################## + +def _orig_hw_tuple(pred): + ohw = pred._orig_hw + # Cases seen in SAM/SAM2 wrappers: (H,W), [H,W], [(H,W)], possibly numpy ints + if isinstance(ohw, list): + # list of tuples -> pick the last tuple + if len(ohw) and isinstance(ohw[-1], (list, tuple)): + h, w = ohw[-1] + else: + # flat [H, W] + h, w = ohw[0], ohw[1] + elif isinstance(ohw, tuple): + h, w = ohw + else: + # e.g., numpy array-like + h, w = ohw[0], ohw[1] + return (int(h), int(w)) + +@torch.no_grad() +def infer_on_single_image( + predictor, + image, + inst_points, # list[Tensor(Pi,2)] + inst_labels, # list[Tensor(Pi)] + inst_masks=None, # optional list[H,W] + inst_boxes=None, # list[Tensor(4)] or None + use_boxes=True, + predict_multimask=False, + device="cuda", +): + # 1) Encode image once + predictor.set_image(image) + + # 2) Normalize cached features to [1,C,H',W'] + def _to_batched_4d(x): + if isinstance(x, (list, tuple)): + x = x[0] + if x.dim() == 3: x = x.unsqueeze(0) + return x.to(device) + + image_embeddings = _to_batched_4d(predictor._features["image_embed"]) # [1,C,H′,W′] + high_res_features = [_to_batched_4d(lvl) for lvl in predictor._features["high_res_feats"]] + + # 3) Pack prompts to (N,P,2) / (N,P) + pts_all = [torch.as_tensor(p, device=device, dtype=torch.float32) for p in 
inst_points] + lbl_all = [torch.as_tensor(l, device=device, dtype=torch.float32) for l in inst_labels] + N = len(pts_all) + if N == 0: + return None, None, None + + P = max(p.shape[0] for p in pts_all) + pts_pad = torch.zeros((N, P, 2), device=device) + lbl_pad = torch.full((N, P), -1.0, device=device) + for i,(p,l) in enumerate(zip(pts_all, lbl_all)): + pts_pad[i, :p.shape[0]] = p + lbl_pad[i, :l.shape[0]] = l + + boxes = None + if use_boxes and inst_boxes: + boxes = torch.stack([torch.as_tensor(bx, device=device, dtype=torch.float32) for bx in inst_boxes], dim=0) + + # 4) Prompt encoding + _, unnorm_coords, labels, _ = predictor._prep_prompts( + pts_pad, lbl_pad, box=boxes, mask_logits=None, normalize_coords=True + ) + sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder( + points=(unnorm_coords, labels), + boxes=boxes if use_boxes else None, + masks=None, + ) + + # 5) Decode with repeat_image=True ✅ no manual feature tiling + low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder( + image_embeddings=image_embeddings, # [1,C,H′,W′] + image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, # [N, ...] + dense_prompt_embeddings=dense_embeddings, # [N, ...] 
+ multimask_output=predict_multimask, + repeat_image=True, # <-- let decoder repeat internally + high_res_features=high_res_features, # each [1,C,H′,W′] + ) + + # 6) Upscale + optional stack GT + def _orig_hw_tuple(pred): + ohw = pred._orig_hw + if isinstance(ohw, list) and len(ohw) and isinstance(ohw[-1], (list, tuple)): + return (int(ohw[-1][0]), int(ohw[-1][1])) + if isinstance(ohw, tuple): + return (int(ohw[0]), int(ohw[1])) + return tuple(int(v) for v in ohw) + + out_hw = _orig_hw_tuple(predictor) # (H, W) e.g., (928, 960) + prd_masks = predictor._transforms.postprocess_masks(low_res_masks, out_hw) # [N,K,H,W] + + gt_masks = None + if inst_masks: + gt_masks = torch.stack([torch.as_tensor(m, device=device).float() for m in inst_masks], dim=0) + + return prd_masks, prd_scores, gt_masks \ No newline at end of file diff --git a/saber/finetune/losses.py b/saber/finetune/losses.py index 4f9bb98..f41cd63 100644 --- a/saber/finetune/losses.py +++ b/saber/finetune/losses.py @@ -45,68 +45,60 @@ def forward(self, prd_masks, prd_scores, gt_masks): prd_scores: [N, K] predicted IoU scores gt_masks: [N, H, W] float {0,1} """ + device = prd_masks.device N, K, H, W = prd_masks.shape + gt_masks = gt_masks.to(prd_masks.dtype) - # compute per-proposal losses - loss_mask_k, loss_dice_k = [], [] - for k in range(K): - l_focal = focal_loss_from_logits( - prd_masks[:, k], gt_masks, - alpha=self.focal_alpha, gamma=self.focal_gamma - ) # scalar over batch - l_dice = dice_loss_from_logits(prd_masks[:, k], gt_masks) # [N] - loss_mask_k.append(l_focal.expand_as(l_dice)) - loss_dice_k.append(l_dice) - loss_mask_k = torch.stack(loss_mask_k, dim=1) # [N,K] - loss_dice_k = torch.stack(loss_dice_k, dim=1) # [N,K] + # ---- 1) Choose proposal by *true IoU* (no grad) ------------------------- + with torch.no_grad(): + probs_k = prd_masks.sigmoid() # [N,K,H,W] + pred_bin_k = (probs_k > 0.5).to(gt_masks.dtype) # hard masks + gt_k = gt_masks[:, None].expand_as(pred_bin_k) # [N,K,H,W] + inter = 
(pred_bin_k * gt_k).sum(dim=(2, 3)) # [N,K] + union = (pred_bin_k + gt_k - pred_bin_k * gt_k)\ + .sum(dim=(2, 3)).clamp_min(1e-6) # [N,K] + true_iou_k = inter / union # [N,K] + best_idx = true_iou_k.argmax(dim=1) # [N] - # combine to pick best proposal per instance - combo = (self.weight_dict["loss_mask"] * loss_mask_k + - self.weight_dict["loss_dice"] * loss_dice_k) - best_idx = combo.argmin(dim=1) # [N] - row = torch.arange(N, device=prd_masks.device) + row = torch.arange(N, device=device) + logits_star = prd_masks[row, best_idx] # [N,H,W] + score_star = prd_scores[row, best_idx] # [N] + true_iou_star = true_iou_k[row, best_idx].detach() # [N] - # select best proposal losses - loss_mask = loss_mask_k[row, best_idx].mean() - loss_dice = loss_dice_k[row, best_idx].mean() + # ---- 2) Mask losses on the chosen proposal ------------------------------ + l_focal = focal_loss_from_logits( + logits_star, gt_masks, + alpha=self.focal_alpha, gamma=self.focal_gamma + ).mean() # scalar - # IoU calibration loss - with torch.no_grad(): - probs = torch.sigmoid(prd_masks[row, best_idx]) # [N,H,W] - pred_bin = (probs > 0.5).float() - inter = (gt_masks * pred_bin).sum(dim=(1, 2)) - union = gt_masks.sum(dim=(1, 2)) + pred_bin.sum(dim=(1, 2)) - inter + 1e-6 - true_iou = inter / union # [N] + l_dice = dice_loss_from_logits(logits_star, gt_masks).mean() # scalar - if self.supervise_all_iou: - # supervise all proposals - iou_targets = [] - for k in range(K): - probs = torch.sigmoid(prd_masks[:, k]) - pred_bin = (probs > 0.5).float() - inter = (gt_masks * pred_bin).sum(dim=(1, 2)) - union = gt_masks.sum(dim=(1, 2)) + pred_bin.sum(dim=(1, 2)) - inter + 1e-6 - iou_targets.append(inter / union) - iou_targets = torch.stack(iou_targets, dim=1) # [N,K] - if self.iou_use_l1_loss: - loss_iou = F.l1_loss(prd_scores, iou_targets) - else: - loss_iou = F.mse_loss(prd_scores, iou_targets) + # ---- 3) IoU head regression on the chosen proposal ---------------------- + if self.iou_use_l1_loss: + l_iou 
= F.smooth_l1_loss(score_star, true_iou_star) else: - score_best = prd_scores[row, best_idx] # [N] + l_iou = F.mse_loss(score_star, true_iou_star) + + # (optional) small regularizer on *all* proposals to stabilize ranking + if self.supervise_all_iou: if self.iou_use_l1_loss: - loss_iou = F.l1_loss(score_best, true_iou) + l_iou_all = F.smooth_l1_loss(prd_scores, true_iou_k.detach()) else: - loss_iou = F.mse_loss(score_best, true_iou) + l_iou_all = F.mse_loss(prd_scores, true_iou_k.detach()) + l_iou = l_iou + 0.1 * l_iou_all # small weight; tune 0.05–0.2 + + # ---- 4) Weighted sum ----------------------------------------------------- + loss_mask = l_focal + loss_dice = l_dice + loss_iou = l_iou - # weighted sum total_loss = (self.weight_dict["loss_mask"] * loss_mask + - self.weight_dict["loss_dice"] * loss_dice + - self.weight_dict["loss_iou"] * loss_iou) + self.weight_dict["loss_dice"] * loss_dice + + self.weight_dict["loss_iou"] * loss_iou) return { "loss_mask": loss_mask, "loss_dice": loss_dice, - "loss_iou": loss_iou, + "loss_iou": loss_iou, "loss_total": total_loss, } \ No newline at end of file diff --git a/saber/finetune/metrics.py b/saber/finetune/metrics.py index df7b6d3..c7d9bc8 100644 --- a/saber/finetune/metrics.py +++ b/saber/finetune/metrics.py @@ -81,6 +81,7 @@ def automask_metrics( ) if amg_kwargs: _amg.update(amg_kwargs) + print(_amg) model_for_amg = getattr(sam2_model_or_predictor, "model", sam2_model_or_predictor) mask_generator = SAM2AutomaticMaskGenerator(model=model_for_amg, **_amg) @@ -91,6 +92,8 @@ def automask_metrics( # AutoCast ac = autocast_ctx if autocast_ctx is not None else (lambda: nullcontext()) + print('TEST') + # Per-image loop per_image, abiou_vals = [], [] for img, gt_list in zip(images, gt_masks_per_image): diff --git a/saber/finetune/trainer.py b/saber/finetune/trainer.py index ac954b8..1a88754 100644 --- a/saber/finetune/trainer.py +++ b/saber/finetune/trainer.py @@ -3,8 +3,8 @@ from saber.finetune.metrics import 
automask_metrics from saber.finetune.losses import MultiMaskIoULoss from lightning import fabric +import torch, os, optuna from tqdm import tqdm -import torch, os class SAM2FinetuneTrainer: def __init__(self, predictor, train_loader, val_loader): @@ -47,10 +47,25 @@ def _autocast(): # Initialize the loss function self.focal_alpha = 0.25 self.focal_gamma = 2.0 - self.supervise_all_iou = True + self.supervise_all_iou = False self.iou_use_l1_loss = False self.predict_multimask = True + # Automask Generator Parameters + self.amg_kwargs = dict( + points_per_side=32, + points_per_batch=128, + pred_iou_thresh=0.5, + stability_score_thresh=0.7, + stability_score_offset=0.0, + crop_n_layers=0, + crop_n_points_downscale_factor=2, + box_nms_thresh=0.9, + use_m2m=False, + multimask_output=True, + ) + self.nAMGtrials = 10 + # Initialize the use_boxes flag self.use_boxes = False @@ -154,7 +169,7 @@ def forward_step(self, batch): return prd_masks, prd_scores, gt_masks, inst_img_ix @torch.no_grad() - def validate_step(self): + def validate_step(self, amg_kwargs=None, max_images=float('inf'), reduce_all_ranks=True): """ Validate the model on the given batch. 
""" @@ -166,8 +181,12 @@ def validate_step(self): loss_sum = torch.tensor(0.0, device=self.device) n_imgs = torch.tensor(0.0, device=self.device) n_inst = torch.tensor(0.0, device=self.device) + + if amg_kwargs is None: + amg_kwargs = self.amg_kwargs # Each rank iterates only its shard (Fabric sets DistributedSampler for you) + num_images = 0 for batch in self.val_loader: # Compute Loss on decoder outputs @@ -191,15 +210,19 @@ def validate_step(self): top_k=20, device=self.device, autocast_ctx=self.autocast, + amg_kwargs=amg_kwargs, ) # Weight by number of images so we can average correctly later num = float(m["num_images"]) abiou_sum += torch.tensor(m["ABIoU"] * num, device=self.device) n_imgs += torch.tensor(num, device=self.device) + num_images += num + if num_images >= max_images: + break # Global reduction (sum across all ranks) - if self.use_fabric: + if self.use_fabric and reduce_all_ranks: loss_sum = self.fabric.all_reduce(loss_sum, reduce_op="sum") abiou_sum = self.fabric.all_reduce(abiou_sum, reduce_op="sum") n_imgs = self.fabric.all_reduce(n_imgs, reduce_op="sum") @@ -286,7 +309,10 @@ def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 10): self.scheduler.step() # Validate - metrics = self.validate_step() + if (epoch+1) % 50 == 0: + metrics = self.amg_param_tuner() + else: + metrics = self.validate_step() # Print Only on Rank 0 if self.is_global_zero: @@ -309,4 +335,79 @@ def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 10): best_metric_value = metrics[best_metric] ckpt = {"model": self.predictor.model.state_dict()} torch.save(ckpt, f"{self.save_dir}/best_model.pth") - print(f"Best {best_metric} saved!") \ No newline at end of file + print(f"Best {best_metric} saved!") + + def amg_param_tuner(self, n_trials=10): + """ + Tune a few AMG thresholds with Bayesian optimization (TPE). + Warm-start from the current self.amg_kwargs. 
+ """ + + if self.use_fabric and not self.is_global_zero: + # Non-zero ranks: wait for rank 0 to finish tuning and broadcast params. + self.fabric.barrier() + # Receive updated dict from rank 0 + self.amg_kwargs = self.fabric.broadcast(self.amg_kwargs, src=0) + # Now run normal distributed validation so logs are comparable + return self.validate_step() + + + # Use a fixed sampler (seed for reproducibility) + study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=0)) + + # ---- Warm start with current params ---- + # IMPORTANT: names must match suggest_* names in objective + warm = { + "pred_iou_thresh": float(self.amg_kwargs["pred_iou_thresh"]), + "stability_score_thresh": float(self.amg_kwargs["stability_score_thresh"]), + "stability_score_offset": float(self.amg_kwargs["stability_score_offset"]), + } + study.enqueue_trial(warm) + + def objective(trial: optuna.Trial) -> float: + """ + Objective for Optuna: maximize ABIoU on a held-out validation set, + varying only a few AMG thresholds. Do NOT mutate self.amg_kwargs here. 
+ """ + # Suggest in sensible finetuned ranges + pred_iou_thresh = trial.suggest_float("pred_iou_thresh", 0.40, 0.75) + stability_score_thresh = trial.suggest_float("stability_score_thresh", 0.55, 0.90) + stability_score_offset = trial.suggest_float("stability_score_offset", 0.00, 0.30) + + # Build a LOCAL kwargs dict (copy), keep other knobs fixed + amg_kwargs_trial = dict(self.amg_kwargs) + amg_kwargs_trial.update({ + "pred_iou_thresh": pred_iou_thresh, + "stability_score_thresh": stability_score_thresh, + "stability_score_offset": stability_score_offset, + }) + + # Validate the model with the new AMG thresholds + metrics = self.validate_step( + amg_kwargs_trial, max_images=100, + reduce_all_ranks=not self.use_fabric + ) + return metrics['ABIoU'] + + # Optimize + study.optimize(objective, n_trials=n_trials, show_progress_bar=False) + + # Update trainer state with the best params + best = study.best_params + self.amg_kwargs.update({ + "pred_iou_thresh": float(best["pred_iou_thresh"]), + "stability_score_thresh": float(best["stability_score_thresh"]), + "stability_score_offset": float(best["stability_score_offset"]), + }) + + # Let other ranks proceed and receive the dict + if self.use_fabric: + self.fabric.barrier() + self.amg_kwargs = self.fabric.broadcast(self.amg_kwargs, src=0) + + + if self.is_global_zero: + print("AMG tuned →", {k: self.amg_kwargs[k] for k in ["pred_iou_thresh","stability_score_thresh","stability_score_offset"]}) + + # Now run the normal distributed validate_step() so metrics are globally averaged + return self.validate_step() \ No newline at end of file From 322fbf060f9f74ffe4202ef89df8b9e687a12dab Mon Sep 17 00:00:00 2001 From: root Date: Wed, 24 Sep 2025 05:03:59 +0000 Subject: [PATCH 08/15] add augmentations --- saber/classifier/datasets/augment.py | 35 +++++++++++++++++++++------- saber/finetune/dataset.py | 28 ++++++++++++++-------- 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/saber/classifier/datasets/augment.py 
b/saber/classifier/datasets/augment.py index 077c731..c339c20 100755 --- a/saber/classifier/datasets/augment.py +++ b/saber/classifier/datasets/augment.py @@ -2,10 +2,11 @@ Compose, EnsureChannelFirstd, NormalizeIntensityd, Orientationd, RandRotate90d, RandFlipd, RandScaleIntensityd, RandShiftIntensityd, RandAdjustContrastd, RandGaussianNoised, RandAffined, RandomOrder, - RandGaussianSmoothd, + RandGaussianSmoothd, SqueezeDimd, ToNumpyd, EnsureTyped, ResizeD ) from saber.classifier.datasets.RandMaskCrop import AdaptiveCropd from torch.utils.data import random_split +import torch def get_preprocessing_transforms(random_translations=False): transforms = Compose([ @@ -20,13 +21,6 @@ def get_preprocessing_transforms(random_translations=False): def get_training_transforms(): train_transforms = Compose([ - # RandAffined( - # keys=["image", "mask"], - # prob=0.75, - # translate_range=(30, 30), - # padding_mode="border", - # mode=("bilinear", "nearest") - # ), RandomOrder([ RandRotate90d(keys=["image", "mask"], prob=0.5, spatial_axes=[0, 1]), RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=0), @@ -45,4 +39,27 @@ def get_validation_transforms(): def split_dataset(dataset, val_split=0.2): train_size = int(len(dataset) * (1 - val_split)) val_size = len(dataset) - train_size - return random_split(dataset, [train_size, val_size]) \ No newline at end of file + return random_split(dataset, [train_size, val_size]) + +def get_finetune_transforms(target_size=(1024,1024)): + transforms = Compose([ + EnsureChannelFirstd(keys=["image", "mask"], channel_dim="no_channel"), + EnsureTyped(keys=["image", "mask"], dtype=[torch.float32, torch.int64]), + ResizeD( + keys=["image", "mask"], + spatial_size=target_size, + mode=("bilinear", "nearest") + ), + RandomOrder([ + RandRotate90d(keys=["image", "mask"], prob=0.5, spatial_axes=[0, 1]), + RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=0), + RandScaleIntensityd(keys="image", prob=0.5, factors=(0.85, 1.15)), + 
RandShiftIntensityd(keys="image", prob=0.5, offsets=(-0.15, 0.15)), + RandAdjustContrastd(keys="image", prob=0.5, gamma=(0.85, 1.15)), + RandGaussianNoised(keys="image", prob=0.5, mean=0.0, std=1.5), + RandGaussianSmoothd(keys="image", prob=0.5, sigma_x=(0.25, 1.5), sigma_y=(0.25, 1.5)), + ]), + SqueezeDimd(keys=["image", "mask"], dim=0), + ToNumpyd(keys=["image", "mask"]), + ]) + return transforms \ No newline at end of file diff --git a/saber/finetune/dataset.py b/saber/finetune/dataset.py index f320e22..17757b8 100644 --- a/saber/finetune/dataset.py +++ b/saber/finetune/dataset.py @@ -1,11 +1,18 @@ from scipy.ndimage import binary_erosion, binary_dilation +from monai.transforms import Compose, EnsureChannelFirstd import saber.finetune.helper as helper -from saber.utils import preprocessing +from saber.utils import preprocessing from torch.utils.data import Dataset from tqdm import tqdm import numpy as np import zarr, torch +from monai.transforms import ( + Compose, EnsureChannelFirstd, RandRotate90d, RandFlipd, RandScaleIntensityd, + RandShiftIntensityd, RandAdjustContrastd, RandGaussianNoised, + RandomOrder, RandGaussianSmoothd, +) + class AutoMaskDataset(Dataset): def __init__(self, tomogram_zarr_path: str = None, @@ -168,11 +175,8 @@ def _get_tomogram_item(self, idx): image_slab = self.tomogram_store[key]['0'][z_start:z_end] seg_slab = self.tomogram_store[key]['labels/0'][z_start:z_end] - # Project slab and normalize - image_projection = preprocessing.project_tomogram(image_slab) - image_2d = preprocessing.proprocess(image_projection) # 3xHxW - - # Project segmentation + # Project slab and segmentation + image_2d = preprocessing.project_tomogram(image_slab) seg_2d = preprocessing.project_segmentation(seg_slab) # HxW return self._package_image_item(image_2d, seg_2d) @@ -302,7 +306,7 @@ def _package_image_item(self, - Emits only positive clicks + boxes (no negatives). 
Returns: { - "image": HxWx3 uint8, + "image": HxW, "masks": list[H x W] float32 in {0,1}, "points": list[#p x 2] float32 (xy), "labels": list[#p] float32 (all ones), @@ -312,9 +316,10 @@ def _package_image_item(self, # Apply transforms to image and segmentation if self.transform: - data = self.transform({'image': image_2d, 'masks': segmentation}) - image_2d, segmentation = data['image'], data['masks'] + sample = self.transform({'image': image_2d, 'mask': segmentation}) + image_2d, segmentation = sample['image'], sample['mask'] + # Get image and segmentation shapes h, w = segmentation.shape min_pixels = 0 # min_pixels = int(self.min_area * h * w) @@ -351,8 +356,11 @@ def _package_image_item(self, labels_t = [torch.from_numpy(np.ones((1,), dtype=np.float32))] boxes_t = [torch.from_numpy(np.array([0, 0, 1, 1], dtype=np.float32))] + # Normalize the Image + image_2d = preprocessing.proprocess(image_2d) # 3xHxW + return { - "image": image_2d, # HxWx3 uint8 + "image": image_2d, # HxWx3 "masks": masks_t, # list[H x W] float32 in {0,1} "points": points_t, # list[#p x 2] float32 (xy) "labels": labels_t, # list[#p] all ones From 8830e726ac9e9469d288ac5c514ae60160fc0b98 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 26 Sep 2025 01:10:25 +0000 Subject: [PATCH 09/15] make sure metrics is tighter --- saber/finetune/helper.py | 2 +- saber/finetune/metrics.py | 321 ++++++++++++++++++++++++------ saber/finetune/train.py | 15 +- saber/finetune/trainer.py | 2 +- saber/visualization/classifier.py | 78 +++++--- 5 files changed, 319 insertions(+), 99 deletions(-) diff --git a/saber/finetune/helper.py b/saber/finetune/helper.py index f57634d..716f686 100644 --- a/saber/finetune/helper.py +++ b/saber/finetune/helper.py @@ -156,7 +156,7 @@ def visualize_item_with_points(image, masks, points, boxes=None, add_masks(mstack, ax) # uses get_colors internally # color-match points (and boxes) to mask color - colors = get_colors() + colors = get_colors(len(masks)) for i, pts in enumerate(points): 
if pts is None or len(pts) == 0: continue diff --git a/saber/finetune/metrics.py b/saber/finetune/metrics.py index c7d9bc8..3657927 100644 --- a/saber/finetune/metrics.py +++ b/saber/finetune/metrics.py @@ -1,4 +1,5 @@ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator +from typing import List, Dict, Any, Optional, Callable, Union from contextlib import nullcontext import numpy as np import torch @@ -40,33 +41,197 @@ def _abiou(proposals, gts): best = iou.max(dim=1).values # [Ng] return best.mean() -# --------------------- ABIoU evaluator --------------------- +# --------------------- Utilities --------------------- + +def _to_bool_tensor(x: Union[np.ndarray, torch.Tensor], device: torch.device) -> torch.Tensor: + """ + Accepts HxW or [N,H,W] arrays/tensors, binarizes (>0) and returns bool tensor on device with shape [N,H,W]. + """ + if isinstance(x, np.ndarray): + arr = x + if arr.ndim == 2: + arr = arr[None, ...] + t = torch.from_numpy(arr) + elif isinstance(x, torch.Tensor): + t = x + if t.ndim == 2: + t = t.unsqueeze(0) + else: + raise TypeError(f"Unsupported mask type: {type(x)}") + # binarize and cast to bool + t = (t != 0) + return t.to(device=device, dtype=torch.bool) + + +def _downsample_bool_masks(m: torch.Tensor, factor: int) -> torch.Tensor: + """ + Downsample boolean masks by a small integer factor via max-pooling (keeps foreground coverage). 
+ m: [N,H,W] bool + """ + if factor <= 1 or m.numel() == 0: + return m + # reshape for pooling + N, H, W = m.shape + H2 = H // factor + W2 = W // factor + if H2 == 0 or W2 == 0: + return m + # crop to divisible + m = m[:, :H2 * factor, :W2 * factor] + # convert to float for pooling-like reduction via unfold + mf = m.float() + mf = mf.unfold(1, factor, factor).unfold(2, factor, factor) # [N, H2, W2, f, f] + # max over the small window -> any(True) + mf = mf.contiguous().view(N, H2, W2, -1).max(dim=-1).values + return (mf > 0).to(dtype=torch.bool) + + +# --------------------- IoU (vectorized) --------------------- + +def _pairwise_iou_bool(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + """ + a: [Na,H,W] bool, b: [Nb,H,W] bool -> IoU [Na,Nb] float32 + Vectorized via flatten + matmul. Stays on device. + """ + Na = a.shape[0] + Nb = b.shape[0] + if Na == 0 or Nb == 0: + return a.new_zeros((Na, Nb), dtype=torch.float32) + + # Flatten (reshape tolerates non-contiguous inputs) + a_f = a.reshape(Na, -1).float() + b_f = b.reshape(Nb, -1).float() + + # Areas and intersections + areas_a = a_f.sum(dim=1) # [Na] + areas_b = b_f.sum(dim=1) # [Nb] + inter = a_f @ b_f.t() # [Na,Nb] + + # Unions + union = areas_a[:, None] + areas_b[None, :] - inter + eps + return (inter / union).to(torch.float32) + + +# --------------------- Metrics --------------------- + +def abiou_original(proposals: torch.Tensor, gts: torch.Tensor) -> torch.Tensor: + """ + ABIoU as mean over GTs of max IoU to any proposal (allows proposal reuse). 
+ proposals, gts: [N,H,W] bool (on same device) + """ + dev = gts.device if gts.numel() > 0 else (proposals.device if proposals.numel() > 0 else torch.device("cpu")) + if gts.numel() == 0 and proposals.numel() == 0: + return torch.tensor(1.0, device=dev, dtype=torch.float32) + if gts.numel() == 0 or proposals.numel() == 0: + return torch.tensor(0.0, device=dev, dtype=torch.float32) + + iou = _pairwise_iou_bool(gts, proposals) # [Ng,Np] + best = iou.max(dim=1).values # [Ng] + return best.mean() + + +def abiou_one_to_one_greedy(proposals: torch.Tensor, gts: torch.Tensor) -> torch.Tensor: + """ + ABIoU with one-to-one greedy matching (no proposal reuse). + proposals, gts: [N,H,W] bool + """ + dev = gts.device if gts.numel() > 0 else (proposals.device if proposals.numel() > 0 else torch.device("cpu")) + if gts.numel() == 0 and proposals.numel() == 0: + return torch.tensor(1.0, device=dev, dtype=torch.float32) + if gts.numel() == 0 or proposals.numel() == 0: + return torch.tensor(0.0, device=dev, dtype=torch.float32) + + iou = _pairwise_iou_bool(gts, proposals) # [Ng,Np] + Ng, Np = iou.shape + used_g = torch.zeros(Ng, dtype=torch.bool, device=dev) + used_p = torch.zeros(Np, dtype=torch.bool, device=dev) + + matched_sum = torch.tensor(0.0, device=dev) + # Greedy loop at most min(Ng, Np) steps + for _ in range(min(Ng, Np)): + # mask used rows/cols by setting to -1 + iou_masked = iou.clone() + if used_g.any(): + iou_masked[used_g, :] = -1 + if used_p.any(): + iou_masked[:, used_p] = -1 + val, idx = torch.max(iou_masked.view(-1), dim=0) + if val <= 0: + break + g_idx = idx // Np + p_idx = idx % Np + matched_sum = matched_sum + val + used_g[g_idx] = True + used_p[p_idx] = True + + # Average over ALL GTs (unmatched GTs count 0) + return matched_sum / max(Ng, 1) + + +def union_iou(proposals: torch.Tensor, gts: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + """ + Pixel-set IoU between union of proposals and union of GTs. 
+ proposals, gts: [N,H,W] bool + """ + dev = gts.device if gts.numel() > 0 else (proposals.device if proposals.numel() > 0 else torch.device("cpu")) + if gts.numel() == 0 and proposals.numel() == 0: + return torch.tensor(1.0, device=dev, dtype=torch.float32) + + if proposals.numel() > 0: + P = proposals.any(dim=0) + else: + # create zero map based on gts ref + H, W = gts.shape[-2], gts.shape[-1] + P = torch.zeros((H, W), dtype=torch.bool, device=dev) + + if gts.numel() > 0: + G = gts.any(dim=0) + else: + H, W = proposals.shape[-2], proposals.shape[-1] + G = torch.zeros((H, W), dtype=torch.bool, device=dev) + + inter = (P & G).sum().float() + uni = (P | G).sum().float() + eps + return (inter / uni).to(torch.float32) + + +# --------------------- Main evaluator --------------------- @torch.no_grad() def automask_metrics( - sam2_model_or_predictor, - images, - gt_masks_per_image, + sam2_model_or_predictor: Any, + images: List[Union[np.ndarray, torch.Tensor]], # HxW or HxWx3 (uint8 preferred) + gt_masks_per_image: List[List[Union[np.ndarray, torch.Tensor]]], # per-image list of HxW masks *, - amg_kwargs=None, - top_k=20, - device=None, - autocast_ctx=None, # callable -> context manager (e.g., trainer.autocast); if None, no autocast -): - """ - Run SAM2AutomaticMaskGenerator and compute ABIoU only. 
- - Args - ---- - sam2_model_or_predictor : SAM2 model or predictor (we’ll use .model if present) - images : list of images: HxW or HxWx3; uint8 NumPy preferred, but float is fine - gt_masks_per_image : list[list[H x W]]; elements may be NumPy or torch; non-zero = foreground - amg_kwargs : dict forwarded to SAM2AutomaticMaskGenerator - top_k : keep only top-K proposals per image by AMG score (optional) - device : torch device for tensors (defaults to 'cuda' if available) - autocast_ctx : callable returning a context manager for mixed precision during AMG forward only - """ - # Defaults for AMG (tweak as you like) + amg_kwargs: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = 20, + device: Optional[torch.device] = None, + autocast_ctx: Optional[Callable[[], Any]] = None, + downsample_factor: int = 1, + return_per_image: bool = False, +) -> Dict[str, Any]: + """ + Run SAM2AutomaticMaskGenerator once per image and compute: + - ABIoU_one_to_one (greedy, no reuse) + - UnionIoU + - ABIoU_original (optional reference) + + Speed features: + - Single IoU matrix fuels both ABIoUs. + - Everything stays on GPU; masks are boolean. + - Optional downsample_factor (e.g., 2 or 4) for huge speedups. + + Returns: + { + 'ABIoU_one_to_one': float, + 'UnionIoU': float, + 'ABIoU_original': float, + 'num_images': int, + 'per_image': [ ... 
] # if return_per_image + } + """ + + # AMG defaults (safe, tweak as needed) _amg = dict( points_per_side=32, points_per_batch=128, @@ -81,72 +246,106 @@ def automask_metrics( ) if amg_kwargs: _amg.update(amg_kwargs) - print(_amg) + model_for_amg = getattr(sam2_model_or_predictor, "model", sam2_model_or_predictor) mask_generator = SAM2AutomaticMaskGenerator(model=model_for_amg, **_amg) - # Determine Device + # Device if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # AutoCast + # Autocast (AMG forward only) ac = autocast_ctx if autocast_ctx is not None else (lambda: nullcontext()) - print('TEST') + # Accumulators + one2one_vals, union_vals, abiou_orig_vals = [], [], [] + per_image_out = [] - # Per-image loop - per_image, abiou_vals = [], [] for img, gt_list in zip(images, gt_masks_per_image): - # -------- ensure AMG-friendly uint8 numpy image -------- + # ---- Ensure numpy uint8 image for AMG ---- if isinstance(img, torch.Tensor): img_np = img.detach().cpu().numpy() else: img_np = img H, W = img_np.shape[:2] - # -------- AMG forward under autocast (fast path) -------- + # ---- AMG forward ---- with ac(): - props = mask_generator.generate(img_np) # list of dicts + proposals = mask_generator.generate(img_np) # list of dict - if len(props) == 0: - prop_masks = torch.zeros((0, H, W), dtype=torch.uint8, device=device) + # ---- Convert proposals -> [Np,H,W] bool on device ---- + if len(proposals) == 0: + prop_masks = torch.zeros((0, H, W), dtype=torch.bool, device=device) else: - # Sort by predicted_iou/score and optionally keep top_k + # sort by predicted_iou (or score), keep top_k def _score(d): return float(d.get("predicted_iou", d.get("score", 0.0))) - props.sort(key=_score, reverse=True) + proposals.sort(key=_score, reverse=True) if top_k is not None and top_k > 0: - props = props[:top_k] + proposals = proposals[:top_k] + masks_np = [(p["segmentation"] > 0).astype(np.uint8) for p in proposals] + prop_masks = 
torch.from_numpy(np.stack(masks_np, axis=0)).to(device=device, dtype=torch.bool) + prop_masks = prop_masks.contiguous() - masks_np = [(p["segmentation"] > 0).astype(np.uint8) for p in props] - prop_masks = torch.from_numpy(np.stack(masks_np, axis=0)).to(device=device) - - # -------- stack GTs -------- + # ---- Convert GTs -> [Ng,H,W] bool on device ---- if len(gt_list) == 0: - gt_masks = torch.zeros((0, H, W), dtype=torch.uint8, device=device) + gt_masks = torch.zeros((0, H, W), dtype=torch.bool, device=device) else: - gts = [] + gt_bool = [] for g in gt_list: if isinstance(g, torch.Tensor): g_np = g.detach().cpu().numpy() else: g_np = g - gts.append((g_np > 0).astype(np.uint8)) - gt_masks = torch.from_numpy(np.stack(gts, axis=0)).to(device=device) - - # -------- ABIoU in float32 (stable) -------- - abiou = _abiou(prop_masks, gt_masks) - per_image.append({ - "ABIoU": float(abiou.detach().cpu()), - "num_props": int(prop_masks.shape[0]), - "num_gt": int(gt_masks.shape[0]), - }) - abiou_vals.append(abiou) - - ABIoU = torch.stack(abiou_vals).mean().item() if abiou_vals else 0.0 - return { - "ABIoU": ABIoU, - 'ABIoU_per_image': abiou_vals, - "num_images": len(per_image), - "per_image": per_image, + gt_bool.append((g_np > 0).astype(np.uint8)) + gt_masks = torch.from_numpy(np.stack(gt_bool, axis=0)).to(device=device, dtype=torch.bool) + gt_masks = gt_masks.contiguous() + + # ---- Optional downsample (max-pool style) ---- + if downsample_factor > 1: + prop_masks_ds = _downsample_bool_masks(prop_masks, downsample_factor) + gt_masks_ds = _downsample_bool_masks(gt_masks, downsample_factor) + else: + prop_masks_ds = prop_masks + gt_masks_ds = gt_masks + + # ---- Metrics (single IoU matrix shared under the hood) ---- + m_one2one = abiou_one_to_one_greedy(prop_masks_ds, gt_masks_ds) + m_union = union_iou(prop_masks_ds, gt_masks_ds) + m_orig = abiou_original(prop_masks_ds, gt_masks_ds) + + one2one_vals.append(m_one2one) + union_vals.append(m_union) + 
abiou_orig_vals.append(m_orig) + + if return_per_image: + per_image_out.append({ + "ABIoU": float(m_one2one.detach().cpu()), + "ABIoU_original": float(m_orig.detach().cpu()), + "num_props": int(prop_masks.shape[0]), + "num_gt": int(gt_masks.shape[0]), + "H": int(H), + "W": int(W), + }) + + # ---- Averages ---- + if len(one2one_vals) == 0: + return { + "ABIoU_one_to_one": 0.0, + "UnionIoU": 0.0, + "ABIoU_original": 0.0, + "num_images": 0, + "per_image": [], + } + + ABIoU_one_to_one = torch.stack(one2one_vals).mean().item() + ABIoU_original_avg = torch.stack(abiou_orig_vals).mean().item() + + out = { + "ABIoU": ABIoU_one_to_one, + "ABIoU_original": ABIoU_original_avg, + "num_images": len(one2one_vals), } + if return_per_image: + out["per_image"] = per_image_out + return out diff --git a/saber/finetune/train.py b/saber/finetune/train.py index 36dd3e4..478addb 100644 --- a/saber/finetune/train.py +++ b/saber/finetune/train.py @@ -1,3 +1,4 @@ +from saber.classifier.datasets.augment import get_finetune_transforms from sam2.sam2_image_predictor import SAM2ImagePredictor from saber.finetune.trainer import SAM2FinetuneTrainer from saber.finetune.dataset import AutoMaskDataset @@ -31,10 +32,16 @@ def finetune_sam2( predictor.model.sam_prompt_encoder.train(True) # Load data loaders - train_loader = DataLoader( AutoMaskDataset(tomo_train, fib_train), batch_size=batch_size, shuffle=True, - num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) - val_loader = DataLoader( AutoMaskDataset(tomo_val, fib_val), batch_size=batch_size, shuffle=False, - num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) if (tomo_val or fib_val) else train_loader + train_loader = DataLoader( AutoMaskDataset( + tomo_train, fib_train, transform=get_finetune_transforms(), + batch_size=batch_size, shuffle=True, + num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) + ) + val_loader = DataLoader( AutoMaskDataset( + tomo_val, fib_val, + batch_size=batch_size, shuffle=False, + 
num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) + ) if (tomo_val or fib_val) else train_loader # Initialize trainer and train trainer = SAM2FinetuneTrainer( predictor, train_loader, val_loader ) diff --git a/saber/finetune/trainer.py b/saber/finetune/trainer.py index 1a88754..67ec37e 100644 --- a/saber/finetune/trainer.py +++ b/saber/finetune/trainer.py @@ -309,7 +309,7 @@ def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 10): self.scheduler.step() # Validate - if (epoch+1) % 50 == 0: + if (epoch+1) % 500 == 0: metrics = self.amg_param_tuner() else: metrics = self.validate_step() diff --git a/saber/visualization/classifier.py b/saber/visualization/classifier.py index 8445a4e..149198a 100755 --- a/saber/visualization/classifier.py +++ b/saber/visualization/classifier.py @@ -1,5 +1,5 @@ +from matplotlib.colors import ListedColormap, hsv_to_rgb from matplotlib.widgets import TextBox, Button -from matplotlib.colors import ListedColormap import matplotlib.pyplot as plt import numpy as np @@ -29,7 +29,8 @@ def display_mask_list(image: np.ndarray, masks: list, save_button: bool = False) display_mask_array(image, masks, save_button) def display_mask_array(image: np.ndarray, masks: np.ndarray, save_button: bool = False): - colors = get_colors() + n_labels = int(np.max(masks)) + colors = get_colors(n_needed=n_labels) # Create figure with extra space for widgets fig = plt.figure(figsize=(9, 7)) @@ -38,7 +39,7 @@ def display_mask_array(image: np.ndarray, masks: np.ndarray, save_button: bool = ax_img = plt.axes([0.1, 0.2, 0.8, 0.75]) ax_img.imshow(image, cmap='gray') - cmap_colors = [(1, 1, 1, 0)] + colors[:np.max(masks)] + cmap_colors = [(1, 1, 1, 0)] + colors cmap = ListedColormap(cmap_colors) ax_img.imshow(masks, cmap=cmap, alpha=0.6) ax_img.axis('off') @@ -303,45 +304,58 @@ def plot_per_class_metrics(per_class_results, save_path=None): else: plt.show() -def get_colors(): +def get_colors(n_needed=None, alpha=0.5): # Extended vibrant color 
palette - colors = [ - (0, 1, 1, 0.5), # Cyan (bright, high contrast) - (1, 0, 1, 0.5), # Magenta - (0, 0, 1, 0.5), # Blue - (0, 1, 0, 0.5), # Green - - (1, 0.5, 0, 0.5), # Orange - (0.5, 0, 0.5, 0.5), # Purple - (0.2, 0.6, 0.9, 0.5), # Sky Blue - (0.9, 0.2, 0.6, 0.5), # Hot Pink - (0.6, 0.2, 0.8, 0.5), # Violet - - (0.4, 0.7, 0.2, 0.5), # Lime - (0.8, 0.4, 0, 0.5), # Burnt Orange - (0, 0.5, 0, 0.5), # Dark Green - (0.7, 0.3, 0.6, 0.5), # Orchid - (0.9, 0.6, 0.2, 0.5), # Gold - - (1, 1, 0.3, 0.5), # Yellow - (0.5, 0.5, 0, 0.5), # Olive - (0, 0, 0.5, 0.5), # Navy - (0.5, 0, 0, 0.5), # Maroon + base = [ + (0, 1, 1, alpha), # Cyan (bright, high contrast) + (1, 0, 1, alpha), # Magenta + (0, 0, 1, alpha), # Blue + (0, 1, 0, alpha), # Green + + (1, 0.5, 0, alpha), # Orange + (0.5, 0, 0.5, alpha), # Purple + (0.2, 0.6, 0.9, alpha), # Sky Blue + (0.9, 0.2, 0.6, alpha), # Hot Pink + (0.6, 0.2, 0.8, alpha), # Violet + + (0.4, 0.7, 0.2, alpha), # Lime + (0.8, 0.4, 0, alpha), # Burnt Orange + (0, 0.5, 0, alpha), # Dark Green + (0.7, 0.3, 0.6, alpha), # Orchid + (0.9, 0.6, 0.2, alpha), # Gold + + (1, 1, 0.3, alpha), # Yellow + (0.5, 0.5, 0, alpha), # Olive + (0, 0, 0.5, alpha), # Navy + (0.5, 0, 0, alpha), # Maroon # Pastel shades (can be used for less prominent classes) - (1, 0.7, 0.7, 0.5), # Light Red/Pink - (0.7, 1, 0.7, 0.5), # Light Green - (0.7, 0.7, 1, 0.5), # Light Blue - (1, 1, 0.7, 0.5), # Light Yellow + (1, 0.7, 0.7, alpha), # Light Red/Pink + (0.7, 1, 0.7, alpha), # Light Green + (0.7, 0.7, 1, alpha), # Light Blue + (1, 1, 0.7, alpha), # Light Yellow ] - return colors + if n_needed is None: + return base + n_needed = int(n_needed) + if n_needed <= len(base): + return base[:n_needed] + + extra_needed = n_needed - len(base) + hues = np.linspace(0, 1, extra_needed, endpoint=False) + extra = [tuple(hsv_to_rgb((h, 0.9, 0.9))) + (alpha,) for h in hues] + + return base + extra def add_masks(masks, ax): + # Get number of masks + num_masks = masks.shape[0] + # Get colors - 
colors = get_colors() + colors = get_colors(n_needed=num_masks) # Get number of masks num_masks = masks.shape[0] From b85408b960a517dc7cf6f8e864d364d34c76eddb Mon Sep 17 00:00:00 2001 From: root Date: Sat, 27 Sep 2025 17:25:05 +0000 Subject: [PATCH 10/15] continue making progress on finetuning --- saber/finetune/abiou.py | 351 ++++++++++++++++++++++++ saber/finetune/dataset.py | 73 +++-- saber/finetune/helper.py | 22 +- saber/finetune/losses.py | 73 +++-- saber/finetune/metrics.py | 557 ++++++++++++++++---------------------- saber/finetune/train.py | 21 +- saber/finetune/trainer.py | 251 +++++++++++------ 7 files changed, 872 insertions(+), 476 deletions(-) create mode 100644 saber/finetune/abiou.py diff --git a/saber/finetune/abiou.py b/saber/finetune/abiou.py new file mode 100644 index 0000000..3657927 --- /dev/null +++ b/saber/finetune/abiou.py @@ -0,0 +1,351 @@ +from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator +from typing import List, Dict, Any, Optional, Callable, Union +from contextlib import nullcontext +import numpy as np +import torch + +# --------------------- IoU / ABIoU --------------------- + +def _mask_iou(a, b, eps=1e-6): + """ + a: [Na,H,W] {0,1}; b: [Nb,H,W] {0,1} -> IoU [Na,Nb] + """ + if a.numel() == 0 or b.numel() == 0: + dev = a.device if a.numel() > 0 else (b.device if b.numel() > 0 else torch.device("cpu")) + Na = a.shape[0] if a.numel() > 0 else 0 + Nb = b.shape[0] if b.numel() > 0 else 0 + return torch.zeros((Na, Nb), device=dev, dtype=torch.float32) + + a = a.float() + b = b.float() + inter = torch.einsum("nhw,mhw->nm", a, b) # [Na,Nb] + + ua = a.sum(dim=(1,2))[:, None] # [Na,1] + ub = b.sum(dim=(1,2))[None, :] # [1,Nb] + + union = ua + ub - inter + eps # [Na,Nb] + return inter / union + +def _abiou(proposals, gts): + """ + Average Best IoU (coverage metric). 
+ proposals: [Np,H,W] {0,1}, gts: [Ng,H,W] {0,1} + """ + if gts.numel() == 0 and proposals.numel() == 0: + dev = proposals.device if proposals.numel() > 0 else gts.device + return torch.tensor(1.0, device=dev, dtype=torch.float32) + if gts.numel() == 0 or proposals.numel() == 0: + dev = proposals.device if proposals.numel() > 0 else gts.device + return torch.tensor(0.0, device=dev, dtype=torch.float32) + iou = _mask_iou(gts, proposals) # [Ng,Np] + best = iou.max(dim=1).values # [Ng] + return best.mean() + +# --------------------- Utilities --------------------- + +def _to_bool_tensor(x: Union[np.ndarray, torch.Tensor], device: torch.device) -> torch.Tensor: + """ + Accepts HxW or [N,H,W] arrays/tensors, binarizes (>0) and returns bool tensor on device with shape [N,H,W]. + """ + if isinstance(x, np.ndarray): + arr = x + if arr.ndim == 2: + arr = arr[None, ...] + t = torch.from_numpy(arr) + elif isinstance(x, torch.Tensor): + t = x + if t.ndim == 2: + t = t.unsqueeze(0) + else: + raise TypeError(f"Unsupported mask type: {type(x)}") + # binarize and cast to bool + t = (t != 0) + return t.to(device=device, dtype=torch.bool) + + +def _downsample_bool_masks(m: torch.Tensor, factor: int) -> torch.Tensor: + """ + Downsample boolean masks by a small integer factor via max-pooling (keeps foreground coverage). 
+ m: [N,H,W] bool + """ + if factor <= 1 or m.numel() == 0: + return m + # reshape for pooling + N, H, W = m.shape + H2 = H // factor + W2 = W // factor + if H2 == 0 or W2 == 0: + return m + # crop to divisible + m = m[:, :H2 * factor, :W2 * factor] + # convert to float for pooling-like reduction via unfold + mf = m.float() + mf = mf.unfold(1, factor, factor).unfold(2, factor, factor) # [N, H2, W2, f, f] + # max over the small window -> any(True) + mf = mf.contiguous().view(N, H2, W2, -1).max(dim=-1).values + return (mf > 0).to(dtype=torch.bool) + + +# --------------------- IoU (vectorized) --------------------- + +def _pairwise_iou_bool(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + """ + a: [Na,H,W] bool, b: [Nb,H,W] bool -> IoU [Na,Nb] float32 + Vectorized via flatten + matmul. Stays on device. + """ + Na = a.shape[0] + Nb = b.shape[0] + if Na == 0 or Nb == 0: + return a.new_zeros((Na, Nb), dtype=torch.float32) + + # Flatten (reshape tolerates non-contiguous inputs) + a_f = a.reshape(Na, -1).float() + b_f = b.reshape(Nb, -1).float() + + # Areas and intersections + areas_a = a_f.sum(dim=1) # [Na] + areas_b = b_f.sum(dim=1) # [Nb] + inter = a_f @ b_f.t() # [Na,Nb] + + # Unions + union = areas_a[:, None] + areas_b[None, :] - inter + eps + return (inter / union).to(torch.float32) + + +# --------------------- Metrics --------------------- + +def abiou_original(proposals: torch.Tensor, gts: torch.Tensor) -> torch.Tensor: + """ + ABIoU as mean over GTs of max IoU to any proposal (allows proposal reuse). 
+ proposals, gts: [N,H,W] bool (on same device) + """ + dev = gts.device if gts.numel() > 0 else (proposals.device if proposals.numel() > 0 else torch.device("cpu")) + if gts.numel() == 0 and proposals.numel() == 0: + return torch.tensor(1.0, device=dev, dtype=torch.float32) + if gts.numel() == 0 or proposals.numel() == 0: + return torch.tensor(0.0, device=dev, dtype=torch.float32) + + iou = _pairwise_iou_bool(gts, proposals) # [Ng,Np] + best = iou.max(dim=1).values # [Ng] + return best.mean() + + +def abiou_one_to_one_greedy(proposals: torch.Tensor, gts: torch.Tensor) -> torch.Tensor: + """ + ABIoU with one-to-one greedy matching (no proposal reuse). + proposals, gts: [N,H,W] bool + """ + dev = gts.device if gts.numel() > 0 else (proposals.device if proposals.numel() > 0 else torch.device("cpu")) + if gts.numel() == 0 and proposals.numel() == 0: + return torch.tensor(1.0, device=dev, dtype=torch.float32) + if gts.numel() == 0 or proposals.numel() == 0: + return torch.tensor(0.0, device=dev, dtype=torch.float32) + + iou = _pairwise_iou_bool(gts, proposals) # [Ng,Np] + Ng, Np = iou.shape + used_g = torch.zeros(Ng, dtype=torch.bool, device=dev) + used_p = torch.zeros(Np, dtype=torch.bool, device=dev) + + matched_sum = torch.tensor(0.0, device=dev) + # Greedy loop at most min(Ng, Np) steps + for _ in range(min(Ng, Np)): + # mask used rows/cols by setting to -1 + iou_masked = iou.clone() + if used_g.any(): + iou_masked[used_g, :] = -1 + if used_p.any(): + iou_masked[:, used_p] = -1 + val, idx = torch.max(iou_masked.view(-1), dim=0) + if val <= 0: + break + g_idx = idx // Np + p_idx = idx % Np + matched_sum = matched_sum + val + used_g[g_idx] = True + used_p[p_idx] = True + + # Average over ALL GTs (unmatched GTs count 0) + return matched_sum / max(Ng, 1) + + +def union_iou(proposals: torch.Tensor, gts: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + """ + Pixel-set IoU between union of proposals and union of GTs. 
+ proposals, gts: [N,H,W] bool + """ + dev = gts.device if gts.numel() > 0 else (proposals.device if proposals.numel() > 0 else torch.device("cpu")) + if gts.numel() == 0 and proposals.numel() == 0: + return torch.tensor(1.0, device=dev, dtype=torch.float32) + + if proposals.numel() > 0: + P = proposals.any(dim=0) + else: + # create zero map based on gts ref + H, W = gts.shape[-2], gts.shape[-1] + P = torch.zeros((H, W), dtype=torch.bool, device=dev) + + if gts.numel() > 0: + G = gts.any(dim=0) + else: + H, W = proposals.shape[-2], proposals.shape[-1] + G = torch.zeros((H, W), dtype=torch.bool, device=dev) + + inter = (P & G).sum().float() + uni = (P | G).sum().float() + eps + return (inter / uni).to(torch.float32) + + +# --------------------- Main evaluator --------------------- + +@torch.no_grad() +def automask_metrics( + sam2_model_or_predictor: Any, + images: List[Union[np.ndarray, torch.Tensor]], # HxW or HxWx3 (uint8 preferred) + gt_masks_per_image: List[List[Union[np.ndarray, torch.Tensor]]], # per-image list of HxW masks + *, + amg_kwargs: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = 20, + device: Optional[torch.device] = None, + autocast_ctx: Optional[Callable[[], Any]] = None, + downsample_factor: int = 1, + return_per_image: bool = False, +) -> Dict[str, Any]: + """ + Run SAM2AutomaticMaskGenerator once per image and compute: + - ABIoU_one_to_one (greedy, no reuse) + - UnionIoU + - ABIoU_original (optional reference) + + Speed features: + - Single IoU matrix fuels both ABIoUs. + - Everything stays on GPU; masks are boolean. + - Optional downsample_factor (e.g., 2 or 4) for huge speedups. + + Returns: + { + 'ABIoU_one_to_one': float, + 'UnionIoU': float, + 'ABIoU_original': float, + 'num_images': int, + 'per_image': [ ... 
] # if return_per_image + } + """ + + # AMG defaults (safe, tweak as needed) + _amg = dict( + points_per_side=32, + points_per_batch=128, + pred_iou_thresh=0.7, + stability_score_thresh=0.92, + stability_score_offset=0.7, + crop_n_layers=1, + crop_n_points_downscale_factor=2, + box_nms_thresh=0.7, + use_m2m=False, + multimask_output=True, + ) + if amg_kwargs: + _amg.update(amg_kwargs) + + model_for_amg = getattr(sam2_model_or_predictor, "model", sam2_model_or_predictor) + mask_generator = SAM2AutomaticMaskGenerator(model=model_for_amg, **_amg) + + # Device + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Autocast (AMG forward only) + ac = autocast_ctx if autocast_ctx is not None else (lambda: nullcontext()) + + # Accumulators + one2one_vals, union_vals, abiou_orig_vals = [], [], [] + per_image_out = [] + + for img, gt_list in zip(images, gt_masks_per_image): + # ---- Ensure numpy uint8 image for AMG ---- + if isinstance(img, torch.Tensor): + img_np = img.detach().cpu().numpy() + else: + img_np = img + H, W = img_np.shape[:2] + + # ---- AMG forward ---- + with ac(): + proposals = mask_generator.generate(img_np) # list of dict + + # ---- Convert proposals -> [Np,H,W] bool on device ---- + if len(proposals) == 0: + prop_masks = torch.zeros((0, H, W), dtype=torch.bool, device=device) + else: + # sort by predicted_iou (or score), keep top_k + def _score(d): + return float(d.get("predicted_iou", d.get("score", 0.0))) + proposals.sort(key=_score, reverse=True) + if top_k is not None and top_k > 0: + proposals = proposals[:top_k] + masks_np = [(p["segmentation"] > 0).astype(np.uint8) for p in proposals] + prop_masks = torch.from_numpy(np.stack(masks_np, axis=0)).to(device=device, dtype=torch.bool) + prop_masks = prop_masks.contiguous() + + # ---- Convert GTs -> [Ng,H,W] bool on device ---- + if len(gt_list) == 0: + gt_masks = torch.zeros((0, H, W), dtype=torch.bool, device=device) + else: + gt_bool = [] + for g in gt_list: 
+ if isinstance(g, torch.Tensor): + g_np = g.detach().cpu().numpy() + else: + g_np = g + gt_bool.append((g_np > 0).astype(np.uint8)) + gt_masks = torch.from_numpy(np.stack(gt_bool, axis=0)).to(device=device, dtype=torch.bool) + gt_masks = gt_masks.contiguous() + + # ---- Optional downsample (max-pool style) ---- + if downsample_factor > 1: + prop_masks_ds = _downsample_bool_masks(prop_masks, downsample_factor) + gt_masks_ds = _downsample_bool_masks(gt_masks, downsample_factor) + else: + prop_masks_ds = prop_masks + gt_masks_ds = gt_masks + + # ---- Metrics (single IoU matrix shared under the hood) ---- + m_one2one = abiou_one_to_one_greedy(prop_masks_ds, gt_masks_ds) + m_union = union_iou(prop_masks_ds, gt_masks_ds) + m_orig = abiou_original(prop_masks_ds, gt_masks_ds) + + one2one_vals.append(m_one2one) + union_vals.append(m_union) + abiou_orig_vals.append(m_orig) + + if return_per_image: + per_image_out.append({ + "ABIoU": float(m_one2one.detach().cpu()), + "ABIoU_original": float(m_orig.detach().cpu()), + "num_props": int(prop_masks.shape[0]), + "num_gt": int(gt_masks.shape[0]), + "H": int(H), + "W": int(W), + }) + + # ---- Averages ---- + if len(one2one_vals) == 0: + return { + "ABIoU_one_to_one": 0.0, + "UnionIoU": 0.0, + "ABIoU_original": 0.0, + "num_images": 0, + "per_image": [], + } + + ABIoU_one_to_one = torch.stack(one2one_vals).mean().item() + ABIoU_original_avg = torch.stack(abiou_orig_vals).mean().item() + + out = { + "ABIoU": ABIoU_one_to_one, + "ABIoU_original": ABIoU_original_avg, + "num_images": len(one2one_vals), + } + if return_per_image: + out["per_image"] = per_image_out + return out diff --git a/saber/finetune/dataset.py b/saber/finetune/dataset.py index 17757b8..f168cc9 100644 --- a/saber/finetune/dataset.py +++ b/saber/finetune/dataset.py @@ -3,15 +3,9 @@ import saber.finetune.helper as helper from saber.utils import preprocessing from torch.utils.data import Dataset +import zarr, torch, random from tqdm import tqdm import numpy as np -import 
zarr, torch - -from monai.transforms import ( - Compose, EnsureChannelFirstd, RandRotate90d, RandFlipd, RandScaleIntensityd, - RandShiftIntensityd, RandAdjustContrastd, RandGaussianNoised, - RandomOrder, RandGaussianSmoothd, -) class AutoMaskDataset(Dataset): def __init__(self, @@ -43,6 +37,7 @@ def __init__(self, self.k_min = 50 self.k_max = 100 self.transform = transform + self.keep_fraction = 0.5 # Check if both data types are available if tomogram_zarr_path is None and fib_zarr_path is None: @@ -88,12 +83,19 @@ def __init__(self, self.seed = seed self._rng = np.random.RandomState(seed) - # Resample epoch - self.resample_epoch() # Verbose Flag self.verbose = False + # Samples + self.tomogram_samples = [] + self.fib_samples = [] + self._prev_tomogram_samples = [] + self._prev_fib_samples = [] + + # First sampling epoch + self.resample_epoch() + def _estimate_zrange(self, key, band=(0.3, 0.7), threshold=0): """ Returns (z_min, z_max) inclusive bounds for valid slab centers @@ -113,45 +115,82 @@ def _estimate_zrange(self, key, band=(0.3, 0.7), threshold=0): def resample_epoch(self): """ Generate new random samples for this epoch """ - self.tomogram_samples = [] - self.fib_samples = [] # Sample random slabs from each tomogram if self.has_tomogram: print(f"Re-Sampling {self.slabs_per_volume_per_epoch} slabs from {self.n_tomogram_volumes} tomograms") + new_tomo_samples = [] for vol_idx in range(self.n_tomogram_volumes): # Sample random z positions from this tomogram volume z_min, z_max = self.tomo_shapes[vol_idx] - z_positions = np.random.randint( + z_positions = self._rng.randint( z_min, z_max, size=self.slabs_per_volume_per_epoch ) # Add to samples for z_pos in z_positions: - self.tomogram_samples.append((vol_idx, z_pos)) - np.random.shuffle(self.tomogram_samples) # Shuffle samples + new_tomo_samples.append((vol_idx, z_pos)) + + self.tomogram_samples = self._update_samples(self.tomogram_samples, new_tomo_samples) + + # Shuffle samples + 
self._rng.shuffle(self.tomogram_samples) # Sample random slices from each FIB volume if self.has_fib: print(f"Re-Sampling {self.slices_per_fib_per_epoch} slices from {self.n_fib_volumes} FIB volumes") + new_fib_samples = [] for fib_idx in range(self.n_fib_volumes): fib_shape = self.fib_shapes[fib_idx] # Sample random z positions from this FIB volume - z_positions = np.random.randint( + z_positions = self._rng.randint( 0, fib_shape[0], size=self.slices_per_fib_per_epoch ) for z_pos in z_positions: - self.fib_samples.append((fib_idx, z_pos)) - np.random.shuffle(self.fib_samples) # Shuffle samples + new_fib_samples.append((fib_idx, z_pos)) + + self.fib_samples = self._update_samples(self.fib_samples, new_fib_samples) + self._rng.shuffle(self.fib_samples) # Shuffle samples # Set epoch length self.epoch_length = len(self.tomogram_samples) + len(self.fib_samples) + def _update_samples(self, old, new): + """ + Return a mixed list with size == len(new): + - keep = min(round(len(new)*keep_fraction), len(old)) from 'old' + - add = len(new) - keep from 'new' + """ + target = len(new) + if target == 0: + return [] + + # choose keep set from *old* list, new set from *new* list + keep = min(int(round(target * self.keep_fraction)), len(old)) + add = target - keep + + # keep set from *old* list + if keep > 0: + keep_idx = self._rng.choice(len(old), size=keep, replace=False) + kept = [old[i] for i in keep_idx] + else: + kept = [] + + # add set from *new* list + if add > 0: + new_idx = self._rng.choice(len(new), size=add, replace=False) + added = [new[i] for i in new_idx] + else: + added = [] + + # return mixed list + return kept + added + def __len__(self): return self.epoch_length diff --git a/saber/finetune/helper.py b/saber/finetune/helper.py index 716f686..f8a9597 100644 --- a/saber/finetune/helper.py +++ b/saber/finetune/helper.py @@ -184,22 +184,32 @@ def visualize_item_with_points(image, masks, points, boxes=None, ax.axis("off") plt.tight_layout() -def 
save_training_log(results, outdir="results"): +def save_training_log(results, outdir="results", metric_keys=["ABIoU"]): # CSV (epoch-aligned, pad with blanks if needed) path = os.path.join(outdir, "metrics.csv") is_new = not os.path.exists(path) with open(path, "a", newline="") as f: - writer = csv.DictWriter(f, fieldnames=["epoch", "lr", "train_loss", "val_loss", "ABIoU"]) + writer = csv.DictWriter(f, + fieldnames=["epoch", "lr_mask", "lr_prompt", "train_loss", "val_loss", + *metric_keys, "train_iou", "train_dice", "train_mask", "val_iou", + "val_dice", "val_mask"]) if is_new: writer.writeheader() writer.writerow({ "epoch": int(results['epoch']), - "lr": f"{results['lr']:.1e}", - "train_loss": float(results['train']['loss']), - "val_loss": float(results['loss']), - "ABIoU": float(results['ABIoU']), + "lr_mask": f"{results['lr_mask']:.1e}", + "lr_prompt": f"{results['lr_prompt']:.1e}", + "train_loss": float(results['train']['loss_total']), + "val_loss": float(results['val']['loss_total']), + **{k: float(results['val'][k]) for k in metric_keys}, + "train_iou": float(results['train']['loss_iou']), + "train_dice": float(results['train']['loss_dice']), + "train_mask": float(results['train']['loss_mask']), + "val_iou": float(results['val']['loss_iou']), + "val_dice": float(results['val']['loss_dice']), + "val_mask": float(results['val']['loss_mask']), }) ######################################################################################## diff --git a/saber/finetune/losses.py b/saber/finetune/losses.py index f41cd63..0818c69 100644 --- a/saber/finetune/losses.py +++ b/saber/finetune/losses.py @@ -1,6 +1,6 @@ -import torch -import torch.nn as nn import torch.nn.functional as F +import torch.nn as nn +import torch def dice_loss_from_logits(logits, targets, eps=1e-6): probs = torch.sigmoid(logits) @@ -42,52 +42,65 @@ def forward(self, prd_masks, prd_scores, gt_masks): Args ---- prd_masks: [N, K, H, W] logits from decoder - prd_scores: [N, K] predicted IoU scores + 
prd_scores: [N, K] predicted IoU logits (will be sigmoided here) gt_masks: [N, H, W] float {0,1} """ + device = prd_masks.device N, K, H, W = prd_masks.shape - gt_masks = gt_masks.to(prd_masks.dtype) + gt_masks = gt_masks.to(prd_masks.dtype) # [N,H,W] - # ---- 1) Choose proposal by *true IoU* (no grad) ------------------------- + # ---- Compute hard predictions and true IoU per proposal (no grad) ---------- with torch.no_grad(): - probs_k = prd_masks.sigmoid() # [N,K,H,W] - pred_bin_k = (probs_k > 0.5).to(gt_masks.dtype) # hard masks - gt_k = gt_masks[:, None].expand_as(pred_bin_k) # [N,K,H,W] - inter = (pred_bin_k * gt_k).sum(dim=(2, 3)) # [N,K] - union = (pred_bin_k + gt_k - pred_bin_k * gt_k)\ - .sum(dim=(2, 3)).clamp_min(1e-6) # [N,K] - true_iou_k = inter / union # [N,K] - best_idx = true_iou_k.argmax(dim=1) # [N] + # probs_k = prd_masks.sigmoid() # [N,K,H,W] + # pred_bin = (probs_k > 0.0).to(gt_masks.dtype) # [N,K,H,W] + pred_bin = (prd_masks > 0.0).to(gt_masks.dtype) # [N,K,H,W] + gt_k = gt_masks[:, None].expand_as(pred_bin) # [N,K,H,W] + inter = (pred_bin * gt_k).sum(dim=(2, 3)) # [N,K] + union = (pred_bin + gt_k - pred_bin * gt_k).sum(dim=(2, 3)).clamp_min(1e-6) + true_iou_k = inter / union # [N,K] in [0,1] + + # ---- Per-proposal segmentation loss (focal + dice), select argmin ---------- + gt_rep = gt_masks.repeat_interleave(K, dim=0) # [N*K,H,W] + focal_per_k = focal_loss_from_logits( + prd_masks.view(N*K, H, W), gt_rep, + alpha=self.focal_alpha, gamma=self.focal_gamma + ).view(N, K) # [N,K] + dice_per_k = dice_loss_from_logits( + prd_masks.view(N*K, H, W), gt_rep + ).view(N, K) # [N,K] + + seg_loss_per_k = focal_per_k + dice_per_k # [N,K] + best_idx = seg_loss_per_k.argmin(dim=1) # [N] choose lowest seg loss row = torch.arange(N, device=device) - logits_star = prd_masks[row, best_idx] # [N,H,W] - score_star = prd_scores[row, best_idx] # [N] - true_iou_star = true_iou_k[row, best_idx].detach() # [N] + logits_star = prd_masks[row, best_idx] # [N,H,W] + 
true_iou_star = true_iou_k[row, best_idx].detach() # [N] - # ---- 2) Mask losses on the chosen proposal ------------------------------ + # ---- Segmentation losses on the chosen proposal ---------------------------- l_focal = focal_loss_from_logits( logits_star, gt_masks, alpha=self.focal_alpha, gamma=self.focal_gamma - ).mean() # scalar + ).mean() # scalar + l_dice = dice_loss_from_logits(logits_star, gt_masks).mean() - l_dice = dice_loss_from_logits(logits_star, gt_masks).mean() # scalar + # ---- IoU head regression (sigmoid + L1 by default) ------------------------- + pred_iou = prd_scores.sigmoid() # [N,K] in [0,1] - # ---- 3) IoU head regression on the chosen proposal ---------------------- if self.iou_use_l1_loss: - l_iou = F.smooth_l1_loss(score_star, true_iou_star) + l_iou = F.l1_loss(pred_iou[row, best_idx], true_iou_star) else: - l_iou = F.mse_loss(score_star, true_iou_star) + l_iou = F.mse_loss(pred_iou[row, best_idx], true_iou_star) - # (optional) small regularizer on *all* proposals to stabilize ranking + # Optional: supervise IoU for *all* proposals with small weight if self.supervise_all_iou: if self.iou_use_l1_loss: - l_iou_all = F.smooth_l1_loss(prd_scores, true_iou_k.detach()) + l_iou_all = F.l1_loss(pred_iou, true_iou_k.detach()) else: - l_iou_all = F.mse_loss(prd_scores, true_iou_k.detach()) - l_iou = l_iou + 0.1 * l_iou_all # small weight; tune 0.05–0.2 + l_iou_all = F.mse_loss(pred_iou, true_iou_k.detach()) + l_iou = l_iou + 0.1 * l_iou_all # tune 0.05–0.2 if needed - # ---- 4) Weighted sum ----------------------------------------------------- + # ---- Weighted sum ---------------------------------------------------------- loss_mask = l_focal loss_dice = l_dice loss_iou = l_iou @@ -97,8 +110,8 @@ def forward(self, prd_masks, prd_scores, gt_masks): self.weight_dict["loss_iou"] * loss_iou) return { - "loss_mask": loss_mask, - "loss_dice": loss_dice, - "loss_iou": loss_iou, + "loss_mask": loss_mask, + "loss_dice": loss_dice, + "loss_iou": 
loss_iou, "loss_total": total_loss, } \ No newline at end of file diff --git a/saber/finetune/metrics.py b/saber/finetune/metrics.py index 3657927..9752551 100644 --- a/saber/finetune/metrics.py +++ b/saber/finetune/metrics.py @@ -1,351 +1,246 @@ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator -from typing import List, Dict, Any, Optional, Callable, Union -from contextlib import nullcontext +from typing import Dict, Any, List, Tuple +import torch.nn.functional as F import numpy as np import torch -# --------------------- IoU / ABIoU --------------------- - -def _mask_iou(a, b, eps=1e-6): - """ - a: [Na,H,W] {0,1}; b: [Nb,H,W] {0,1} -> IoU [Na,Nb] - """ - if a.numel() == 0 or b.numel() == 0: - dev = a.device if a.numel() > 0 else (b.device if b.numel() > 0 else torch.device("cpu")) - Na = a.shape[0] if a.numel() > 0 else 0 - Nb = b.shape[0] if b.numel() > 0 else 0 - return torch.zeros((Na, Nb), device=dev, dtype=torch.float32) - - a = a.float() - b = b.float() - inter = torch.einsum("nhw,mhw->nm", a, b) # [Na,Nb] - - ua = a.sum(dim=(1,2))[:, None] # [Na,1] - ub = b.sum(dim=(1,2))[None, :] # [1,Nb] - - union = ua + ub - inter + eps # [Na,Nb] - return inter / union - -def _abiou(proposals, gts): - """ - Average Best IoU (coverage metric). 
- proposals: [Np,H,W] {0,1}, gts: [Ng,H,W] {0,1} - """ - if gts.numel() == 0 and proposals.numel() == 0: - dev = proposals.device if proposals.numel() > 0 else gts.device - return torch.tensor(1.0, device=dev, dtype=torch.float32) - if gts.numel() == 0 or proposals.numel() == 0: - dev = proposals.device if proposals.numel() > 0 else gts.device - return torch.tensor(0.0, device=dev, dtype=torch.float32) - iou = _mask_iou(gts, proposals) # [Ng,Np] - best = iou.max(dim=1).values # [Ng] - return best.mean() - -# --------------------- Utilities --------------------- - -def _to_bool_tensor(x: Union[np.ndarray, torch.Tensor], device: torch.device) -> torch.Tensor: - """ - Accepts HxW or [N,H,W] arrays/tensors, binarizes (>0) and returns bool tensor on device with shape [N,H,W]. - """ - if isinstance(x, np.ndarray): - arr = x - if arr.ndim == 2: - arr = arr[None, ...] - t = torch.from_numpy(arr) - elif isinstance(x, torch.Tensor): - t = x - if t.ndim == 2: - t = t.unsqueeze(0) - else: - raise TypeError(f"Unsupported mask type: {type(x)}") - # binarize and cast to bool - t = (t != 0) - return t.to(device=device, dtype=torch.bool) - - -def _downsample_bool_masks(m: torch.Tensor, factor: int) -> torch.Tensor: - """ - Downsample boolean masks by a small integer factor via max-pooling (keeps foreground coverage). 
- m: [N,H,W] bool - """ - if factor <= 1 or m.numel() == 0: - return m - # reshape for pooling - N, H, W = m.shape - H2 = H // factor - W2 = W // factor - if H2 == 0 or W2 == 0: - return m - # crop to divisible - m = m[:, :H2 * factor, :W2 * factor] - # convert to float for pooling-like reduction via unfold - mf = m.float() - mf = mf.unfold(1, factor, factor).unfold(2, factor, factor) # [N, H2, W2, f, f] - # max over the small window -> any(True) - mf = mf.contiguous().view(N, H2, W2, -1).max(dim=-1).values - return (mf > 0).to(dtype=torch.bool) +# Subset of IoU thresholds, as requested: +AR_THRESHOLDS = np.array([0.50, 0.65, 0.75, 0.85], dtype=np.float32) +# ------------------------ Decoder-side helpers ------------------------ -# --------------------- IoU (vectorized) --------------------- - -def _pairwise_iou_bool(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: +def _binary_iou(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: """ - a: [Na,H,W] bool, b: [Nb,H,W] bool -> IoU [Na,Nb] float32 - Vectorized via flatten + matmul. Stays on device. + Fast IoU for binary masks (boolean or {0,1} tensors). + Shapes: a,b: (H, W) """ - Na = a.shape[0] - Nb = b.shape[0] - if Na == 0 or Nb == 0: - return a.new_zeros((Na, Nb), dtype=torch.float32) - - # Flatten (reshape tolerates non-contiguous inputs) - a_f = a.reshape(Na, -1).float() - b_f = b.reshape(Nb, -1).float() - - # Areas and intersections - areas_a = a_f.sum(dim=1) # [Na] - areas_b = b_f.sum(dim=1) # [Nb] - inter = a_f @ b_f.t() # [Na,Nb] - - # Unions - union = areas_a[:, None] + areas_b[None, :] - inter + eps - return (inter / union).to(torch.float32) - + inter = (a & b).float().sum() + uni = (a | b).float().sum().clamp_min(eps) + return inter / uni -# --------------------- Metrics --------------------- - -def abiou_original(proposals: torch.Tensor, gts: torch.Tensor) -> torch.Tensor: - """ - ABIoU as mean over GTs of max IoU to any proposal (allows proposal reuse). 
- proposals, gts: [N,H,W] bool (on same device) +@torch.no_grad() +def decoder_prompt_miou(prd_masks: torch.Tensor, gt_masks: torch.Tensor) -> float: """ - dev = gts.device if gts.numel() > 0 else (proposals.device if proposals.numel() > 0 else torch.device("cpu")) - if gts.numel() == 0 and proposals.numel() == 0: - return torch.tensor(1.0, device=dev, dtype=torch.float32) - if gts.numel() == 0 or proposals.numel() == 0: - return torch.tensor(0.0, device=dev, dtype=torch.float32) - - iou = _pairwise_iou_bool(gts, proposals) # [Ng,Np] - best = iou.max(dim=1).values # [Ng] - return best.mean() - + Best-of-K decoder prompt mIoU. + Args: + prd_masks: [N, K, H, W] LOGITS from the decoder + gt_masks: [N, H, W] binary {0,1} + Returns: + mean over N of (max IoU over K), using SAM2's thresholding rule (logits > 0). + """ + N, K, H, W = prd_masks.shape + # SAM2 convention: threshold decoder outputs by logits > 0 + pred_bin = (prd_masks > 0) # bool, [N,K,H,W] + ious = [] + for n in range(N): + gt = gt_masks[n].bool() + if gt.sum() == 0: + continue # skip empty GT + best = torch.stack([_binary_iou(pred_bin[n, k], gt) for k in range(K)], dim=0).max() + ious.append(best) + if len(ious) == 0: + return float("nan") + return float(torch.stack(ious).mean().item()) -def abiou_one_to_one_greedy(proposals: torch.Tensor, gts: torch.Tensor) -> torch.Tensor: - """ - ABIoU with one-to-one greedy matching (no proposal reuse). 
- proposals, gts: [N,H,W] bool +@torch.no_grad() +def iou_head_calibration_from_decoder( + prd_masks: torch.Tensor, + prd_scores: torch.Tensor, + gt_masks: torch.Tensor, + num_bins: int = 15, +) -> Dict[str, Any]: """ - dev = gts.device if gts.numel() > 0 else (proposals.device if proposals.numel() > 0 else torch.device("cpu")) - if gts.numel() == 0 and proposals.numel() == 0: - return torch.tensor(1.0, device=dev, dtype=torch.float32) - if gts.numel() == 0 or proposals.numel() == 0: - return torch.tensor(0.0, device=dev, dtype=torch.float32) - - iou = _pairwise_iou_bool(gts, proposals) # [Ng,Np] - Ng, Np = iou.shape - used_g = torch.zeros(Ng, dtype=torch.bool, device=dev) - used_p = torch.zeros(Np, dtype=torch.bool, device=dev) - - matched_sum = torch.tensor(0.0, device=dev) - # Greedy loop at most min(Ng, Np) steps - for _ in range(min(Ng, Np)): - # mask used rows/cols by setting to -1 - iou_masked = iou.clone() - if used_g.any(): - iou_masked[used_g, :] = -1 - if used_p.any(): - iou_masked[:, used_p] = -1 - val, idx = torch.max(iou_masked.view(-1), dim=0) - if val <= 0: + Compare predicted IoU (sigmoid(prd_scores)) vs true IoU (from logits>0 masks). + Args: + prd_masks: [N,K,H,W] LOGITS + prd_scores: [N,K] raw IoU logits + gt_masks: [N,H,W] + Returns: + dict with calibration MAE, Brier, ECE, and a per-bin table (for diagnostics). 
+ """ + N, K, H, W = prd_masks.shape + preds, trues = [], [] + pred_bin = (prd_masks > 0) # bool + + for n in range(N): + gt = gt_masks[n].bool() + if gt.sum() == 0: + continue + # True IoU per K proposal + true_iou_k = torch.stack([_binary_iou(pred_bin[n, k], gt) for k in range(K)], dim=0) # [K] + pred_iou_k = prd_scores[n].sigmoid().clamp(0, 1) # [K] + trues.append(true_iou_k) + preds.append(pred_iou_k) + + if len(preds) == 0: + return {"mae": float("nan"), "brier": float("nan"), "ece": float("nan"), "table": []} + + preds = torch.cat(preds) # [N'*K] + trues = torch.cat(trues) # [N'*K] + + mae = torch.abs(preds - trues).mean().item() + brier = torch.mean((preds - trues) ** 2).item() + + # Expected Calibration Error (ECE) over uniform bins + bins = np.linspace(0.0, 1.0, num_bins + 1) + pred_np, true_np = preds.cpu().numpy(), trues.cpu().numpy() + idx = np.clip(np.digitize(pred_np, bins, right=True) - 1, 0, num_bins - 1) + + ece = 0.0 + table = [] + total = len(pred_np) + for b in range(num_bins): + m = (idx == b) + n_b = int(m.sum()) + if n_b == 0: + table.append({"bin": f"[{bins[b]:.2f},{bins[b+1]:.2f})", "count": 0, + "mean_pred": None, "mean_true": None, "gap": None}) + continue + mp = float(pred_np[m].mean()) + mt = float(true_np[m].mean()) + gap = abs(mp - mt) + ece += (n_b / total) * gap + table.append({"bin": f"[{bins[b]:.2f},{bins[b+1]:.2f})", "count": n_b, + "mean_pred": round(mp, 4), "mean_true": round(mt, 4), "gap": round(gap, 4)}) + + return {"mae": mae, "brier": float(brier), "ece": float(ece), "table": table} + +# ------------------------ AMG proposal metrics ------------------------ + +def _iou(a: np.ndarray, b: np.ndarray) -> float: + """IoU for boolean numpy masks (H,W).""" + inter = np.logical_and(a, b).sum() + union = np.logical_or(a, b).sum() + return float(inter) / max(1.0, float(union)) + +def _iou_matrix(preds: List[np.ndarray], gts: List[np.ndarray]) -> np.ndarray: + """ + Build P x G IoU matrix for numpy boolean masks. 
+ preds: list of predicted masks (H,W) + gts: list of gt masks (H,W) + """ + if len(preds) == 0 or len(gts) == 0: + return np.zeros((len(preds), len(gts)), dtype=np.float32) + M = np.zeros((len(preds), len(gts)), dtype=np.float32) + for i, p in enumerate(preds): + for j, g in enumerate(gts): + M[i, j] = _iou(p, g) + return M + +def _greedy_match(M: np.ndarray, tau: float) -> Tuple[int, int, int]: + """ + Greedy bipartite matching by IoU descending with threshold tau. + Returns: + TP, FP, FN + """ + P, G = M.shape + used_p, used_g, matches = set(), set(), [] + pairs = [(i, j, M[i, j]) for i in range(P) for j in range(G)] + pairs.sort(key=lambda x: x[2], reverse=True) + for i, j, iou in pairs: + if iou < tau: break - g_idx = idx // Np - p_idx = idx % Np - matched_sum = matched_sum + val - used_g[g_idx] = True - used_p[p_idx] = True - - # Average over ALL GTs (unmatched GTs count 0) - return matched_sum / max(Ng, 1) - - -def union_iou(proposals: torch.Tensor, gts: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: - """ - Pixel-set IoU between union of proposals and union of GTs. 
- proposals, gts: [N,H,W] bool - """ - dev = gts.device if gts.numel() > 0 else (proposals.device if proposals.numel() > 0 else torch.device("cpu")) - if gts.numel() == 0 and proposals.numel() == 0: - return torch.tensor(1.0, device=dev, dtype=torch.float32) - - if proposals.numel() > 0: - P = proposals.any(dim=0) - else: - # create zero map based on gts ref - H, W = gts.shape[-2], gts.shape[-1] - P = torch.zeros((H, W), dtype=torch.bool, device=dev) - - if gts.numel() > 0: - G = gts.any(dim=0) - else: - H, W = proposals.shape[-2], proposals.shape[-1] - G = torch.zeros((H, W), dtype=torch.bool, device=dev) - - inter = (P & G).sum().float() - uni = (P | G).sum().float() + eps - return (inter / uni).to(torch.float32) - - -# --------------------- Main evaluator --------------------- - -@torch.no_grad() -def automask_metrics( - sam2_model_or_predictor: Any, - images: List[Union[np.ndarray, torch.Tensor]], # HxW or HxWx3 (uint8 preferred) - gt_masks_per_image: List[List[Union[np.ndarray, torch.Tensor]]], # per-image list of HxW masks - *, - amg_kwargs: Optional[Dict[str, Any]] = None, - top_k: Optional[int] = 20, - device: Optional[torch.device] = None, - autocast_ctx: Optional[Callable[[], Any]] = None, - downsample_factor: int = 1, - return_per_image: bool = False, + if i in used_p or j in used_g: + continue + used_p.add(i); used_g.add(j); matches.append((i, j)) + TP = len(matches); FP = P - TP; FN = G - TP + return TP, FP, FN + +def average_recall_amg( + amg_outputs: List[List[dict]], + gt_masks_list: List[List[torch.Tensor]], + iou_thresholds: np.ndarray = AR_THRESHOLDS, + max_proposals: int = None, ) -> Dict[str, Any]: """ - Run SAM2AutomaticMaskGenerator once per image and compute: - - ABIoU_one_to_one (greedy, no reuse) - - UnionIoU - - ABIoU_original (optional reference) - - Speed features: - - Single IoU matrix fuels both ABIoUs. - - Everything stays on GPU; masks are boolean. - - Optional downsample_factor (e.g., 2 or 4) for huge speedups. 
- + Average Recall across IoU thresholds. + Args: + amg_outputs: list over images of list[mask_dict]; each dict has 'segmentation' (np.bool array) and 'predicted_iou' + gt_masks_list: list over images of list[tensor HxW] for ground-truth instances + iou_thresholds: numpy array of IoU taus to average over + max_proposals: cap #proposals per image (after ranking by predicted_iou) for speed Returns: - { - 'ABIoU_one_to_one': float, - 'UnionIoU': float, - 'ABIoU_original': float, - 'num_images': int, - 'per_image': [ ... ] # if return_per_image - } + {"AR": scalar, "per_tau_recall": {tau: recall_tau}} + """ + recalls = [] + for tau in iou_thresholds: + tp = fn = 0 + for img_masks, gts in zip(amg_outputs, gt_masks_list): + # rank by predicted_iou then cap + if max_proposals is not None: + img_masks = sorted(img_masks, key=lambda d: d.get('predicted_iou', 0.0), reverse=True)[:max_proposals] + preds = [m['segmentation'].astype(bool) for m in img_masks] + gts_np = [g.cpu().numpy().astype(bool) for g in gts] + M = _iou_matrix(preds, gts_np) + tpi, _, fni = _greedy_match(M, float(tau)) + tp += tpi; fn += fni + denom = tp + fn + recalls.append(tp / denom if denom > 0 else np.nan) + + per_tau = {float(t): (None if np.isnan(r) else float(r)) for t, r in zip(iou_thresholds, recalls)} + return {"AR": float(np.nanmean(recalls)), "per_tau_recall": per_tau} + +def recall_at_k_amg( + amg_outputs: List[List[dict]], + gt_masks_list: List[List[torch.Tensor]], + ks: Tuple[int, ...] = (10, 50, 100), + iou_thresh: float = 0.5, +) -> Dict[str, Any]: """ + Recall@K at a fixed IoU threshold (default 0.5). 
+ """ + out = {} + for K in ks: + tp = fn = 0 + for img_masks, gts in zip(amg_outputs, gt_masks_list): + sel = sorted(img_masks, key=lambda d: d.get('predicted_iou', 0.0), reverse=True)[:K] + preds = [m['segmentation'].astype(bool) for m in sel] + gts_np = [g.cpu().numpy().astype(bool) for g in gts] + M = _iou_matrix(preds, gts_np) + tpi, _, fni = _greedy_match(M, iou_thresh) + tp += tpi; fn += fni + denom = tp + fn + out[K] = tp / denom if denom > 0 else float('nan') + return {"Recall@K": out, "iou_thresh": iou_thresh} - # AMG defaults (safe, tweak as needed) - _amg = dict( - points_per_side=32, - points_per_batch=128, - pred_iou_thresh=0.7, - stability_score_thresh=0.92, - stability_score_offset=0.7, - crop_n_layers=1, - crop_n_points_downscale_factor=2, - box_nms_thresh=0.7, - use_m2m=False, - multimask_output=True, - ) - if amg_kwargs: - _amg.update(amg_kwargs) - - model_for_amg = getattr(sam2_model_or_predictor, "model", sam2_model_or_predictor) - mask_generator = SAM2AutomaticMaskGenerator(model=model_for_amg, **_amg) - - # Device - if device is None: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Autocast (AMG forward only) - ac = autocast_ctx if autocast_ctx is not None else (lambda: nullcontext()) - - # Accumulators - one2one_vals, union_vals, abiou_orig_vals = [], [], [] - per_image_out = [] - - for img, gt_list in zip(images, gt_masks_per_image): - # ---- Ensure numpy uint8 image for AMG ---- - if isinstance(img, torch.Tensor): - img_np = img.detach().cpu().numpy() - else: - img_np = img - H, W = img_np.shape[:2] - - # ---- AMG forward ---- - with ac(): - proposals = mask_generator.generate(img_np) # list of dict - - # ---- Convert proposals -> [Np,H,W] bool on device ---- - if len(proposals) == 0: - prop_masks = torch.zeros((0, H, W), dtype=torch.bool, device=device) - else: - # sort by predicted_iou (or score), keep top_k - def _score(d): - return float(d.get("predicted_iou", d.get("score", 0.0))) - 
proposals.sort(key=_score, reverse=True) - if top_k is not None and top_k > 0: - proposals = proposals[:top_k] - masks_np = [(p["segmentation"] > 0).astype(np.uint8) for p in proposals] - prop_masks = torch.from_numpy(np.stack(masks_np, axis=0)).to(device=device, dtype=torch.bool) - prop_masks = prop_masks.contiguous() - - # ---- Convert GTs -> [Ng,H,W] bool on device ---- - if len(gt_list) == 0: - gt_masks = torch.zeros((0, H, W), dtype=torch.bool, device=device) - else: - gt_bool = [] - for g in gt_list: - if isinstance(g, torch.Tensor): - g_np = g.detach().cpu().numpy() - else: - g_np = g - gt_bool.append((g_np > 0).astype(np.uint8)) - gt_masks = torch.from_numpy(np.stack(gt_bool, axis=0)).to(device=device, dtype=torch.bool) - gt_masks = gt_masks.contiguous() - - # ---- Optional downsample (max-pool style) ---- - if downsample_factor > 1: - prop_masks_ds = _downsample_bool_masks(prop_masks, downsample_factor) - gt_masks_ds = _downsample_bool_masks(gt_masks, downsample_factor) - else: - prop_masks_ds = prop_masks - gt_masks_ds = gt_masks - - # ---- Metrics (single IoU matrix shared under the hood) ---- - m_one2one = abiou_one_to_one_greedy(prop_masks_ds, gt_masks_ds) - m_union = union_iou(prop_masks_ds, gt_masks_ds) - m_orig = abiou_original(prop_masks_ds, gt_masks_ds) - - one2one_vals.append(m_one2one) - union_vals.append(m_union) - abiou_orig_vals.append(m_orig) - - if return_per_image: - per_image_out.append({ - "ABIoU": float(m_one2one.detach().cpu()), - "ABIoU_original": float(m_orig.detach().cpu()), - "num_props": int(prop_masks.shape[0]), - "num_gt": int(gt_masks.shape[0]), - "H": int(H), - "W": int(W), - }) - - # ---- Averages ---- - if len(one2one_vals) == 0: - return { - "ABIoU_one_to_one": 0.0, - "UnionIoU": 0.0, - "ABIoU_original": 0.0, - "num_images": 0, - "per_image": [], - } - - ABIoU_one_to_one = torch.stack(one2one_vals).mean().item() - ABIoU_original_avg = torch.stack(abiou_orig_vals).mean().item() +# ------------------------ Wrapper for 
validation loop ------------------------ - out = { - "ABIoU": ABIoU_one_to_one, - "ABIoU_original": ABIoU_original_avg, - "num_images": len(one2one_vals), +@torch.no_grad() +def sam2_metrics(batch: Dict[str, Any], outputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], amg) -> Dict[str, Any]: + """ + Compute a SAM2-style metric bundle for a single validation batch. + Args: + batch: {"images": list[np.ndarray or torch tensor], "masks": list[list[torch.Tensor HxW]]} + outputs: (prd_masks, prd_scores, gt_masks) where + prd_masks: [N,K,H,W] logits, prd_scores: [N,K] logits, gt_masks: [N,H,W] {0,1} + amg: a SAM2AutomaticMaskGenerator instance + Returns: + dict with prompt_mIoU, IoU calibration (mae/brier/ece + table), + AR averaged over {0.50,0.65,0.75,0.85}, per-threshold recalls, + and Recall@{10,50,100}. Includes "num_images" for weighted averaging. + """ + prd_masks, prd_scores, gt_masks = outputs[:3] + + # Decoder-side metrics (cheap; no extra forward) + pm = decoder_prompt_miou(prd_masks, gt_masks) + cal = iou_head_calibration_from_decoder(prd_masks, prd_scores, gt_masks, num_bins=15) + + # AMG proposals (dominates runtime; run once, reuse for AR and R@K) + all_amg = [amg.generate(img) for img in batch["images"]] + all_gt = batch["masks"] + + ar = average_recall_amg(all_amg, all_gt, iou_thresholds=AR_THRESHOLDS, max_proposals=200) + # rK = recall_at_k_amg(all_amg, all_gt, ks=(10, 50), iou_thresh=0.5) + + return { + "prompt_miou": float(pm) if pm == pm else float("nan"), + "cal_mae": cal["mae"], + "cal_brier": cal["brier"], + "cal_ece": cal["ece"], + "cal_tables": cal["table"], # keep for inspection; don’t reduce across ranks + "AR": ar["AR"], + "num_images": len(batch["images"]), } - if return_per_image: - out["per_image"] = per_image_out - return out + # # "per_tau_recall": ar["per_tau_recall"], # keep for plotting; don’t reduce as a scalar + # "R@10": rK["Recall@K"][10], + # "R@50": rK["Recall@K"][50], \ No newline at end of file diff --git 
a/saber/finetune/train.py b/saber/finetune/train.py index 478addb..82c9357 100644 --- a/saber/finetune/train.py +++ b/saber/finetune/train.py @@ -34,14 +34,15 @@ def finetune_sam2( # Load data loaders train_loader = DataLoader( AutoMaskDataset( tomo_train, fib_train, transform=get_finetune_transforms(), - batch_size=batch_size, shuffle=True, - num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) + slabs_per_volume_per_epoch=10 ), + batch_size=batch_size, shuffle=True, + num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) + val_loader = DataLoader( AutoMaskDataset( - tomo_val, fib_val, - batch_size=batch_size, shuffle=False, - num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) - ) if (tomo_val or fib_val) else train_loader + tomo_val, fib_val, slabs_per_volume_per_epoch=10 ), + num_workers=4, pin_memory=True, collate_fn=collate_autoseg, + batch_size=batch_size, shuffle=False ) if (tomo_val or fib_val) else train_loader # Initialize trainer and train trainer = SAM2FinetuneTrainer( predictor, train_loader, val_loader ) @@ -65,10 +66,10 @@ def finetune(sam2_cfg: str, epochs: int, fib_train: str, fib_val: str, tomo_trai f"Fine Tuning SAM2 on {fib_train} and {fib_val} and {tomo_train} and {tomo_val} for {epochs} epochs" ) print(f"Using SAM2 Config: {sam2_cfg}") - print(f"Using Train Zarr: {tomo_train}") - print(f"Using Val Zarr: {tomo_val}") - print(f"Using Train Zarr: {fib_train}") - print(f"Using Val Zarr: {fib_val}") + print(f"Tomo Train Zarr: {tomo_train}") + print(f"Tomo Val Zarr: {tomo_val}") + print(f"Fib Train Zarr: {fib_train}") + print(f"Fib Val Zarr: {fib_val}") print(f"Using Number of Epochs: {epochs}") print(f"Using Batch Size: {batch_size}") print("--------------------------------") diff --git a/saber/finetune/trainer.py b/saber/finetune/trainer.py index 67ec37e..e5e019a 100644 --- a/saber/finetune/trainer.py +++ b/saber/finetune/trainer.py @@ -1,7 +1,9 @@ from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, 
SequentialLR +from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator from saber.finetune.helper import save_training_log -from saber.finetune.metrics import automask_metrics +from saber.finetune.abiou import automask_metrics from saber.finetune.losses import MultiMaskIoULoss +from saber.finetune.metrics import sam2_metrics from lightning import fabric import torch, os, optuna from tqdm import tqdm @@ -47,20 +49,20 @@ def _autocast(): # Initialize the loss function self.focal_alpha = 0.25 self.focal_gamma = 2.0 - self.supervise_all_iou = False - self.iou_use_l1_loss = False + self.supervise_all_iou = True + self.iou_use_l1_loss = True self.predict_multimask = True # Automask Generator Parameters self.amg_kwargs = dict( points_per_side=32, points_per_batch=128, - pred_iou_thresh=0.5, + pred_iou_thresh=0.7, stability_score_thresh=0.7, stability_score_offset=0.0, crop_n_layers=0, crop_n_points_downscale_factor=2, - box_nms_thresh=0.9, + box_nms_thresh=0.6, use_m2m=False, multimask_output=True, ) @@ -160,33 +162,47 @@ def forward_step(self, batch): high_res_features=hr_feats, ) - # 7) Upscale + stack GT - prd_masks = self.predictor._transforms.postprocess_masks( - low_res_masks, self.predictor._orig_hw[-1] - ) # [N,K,H,W] logits - gt_masks = torch.stack(gt_all, dim=0).float() # [N,H,W] + # 7) Upscale + stack GT + target_sizes = [self.predictor._orig_hw[int(b)] for b in inst_img_ix] # list of (H, W), len = N + # postprocess per instance (supports single size); do it in a loop then stack + upsampled = [] + for i in range(low_res_masks.shape[0]): # N + H, W = target_sizes[i] + up_i = self.predictor._transforms.postprocess_masks( + low_res_masks[i:i+1], (H, W) + ) # [1,K,H,W] + upsampled.append(up_i) + + prd_masks = torch.cat(upsampled, dim=0) # [N,K,H,W] + gt_masks = torch.stack(gt_all, dim=0).float() # [N,H,W] return prd_masks, prd_scores, gt_masks, inst_img_ix @torch.no_grad() - def validate_step(self, amg_kwargs=None, max_images=float('inf'), 
reduce_all_ranks=True): + def validate_step(self, max_images=float('inf'), best_metric='ABIoU'): """ Validate the model on the given batch. """ + # Set the model to evaluation mode self.predictor.model.eval() - # Local accumulators (weighted by number of images in each call) - abiou_sum = torch.tensor(0.0, device=self.device) - loss_sum = torch.tensor(0.0, device=self.device) - n_imgs = torch.tensor(0.0, device=self.device) - n_inst = torch.tensor(0.0, device=self.device) - - if amg_kwargs is None: - amg_kwargs = self.amg_kwargs - - # Each rank iterates only its shard (Fabric sets DistributedSampler for you) - num_images = 0 + # Initialize the AMG + amg = SAM2AutomaticMaskGenerator( + model=self.predictor.model, + **self.amg_kwargs + ) + + # --- local accumulators (tensor) --- + loss_keys = ["loss_total", "loss_iou", "loss_dice", "loss_mask"] + losses_sum = {k: torch.tensor(0.0, device=self.device) for k in loss_keys} + n_inst = torch.tensor(0.0, device=self.device) + n_imgs = torch.tensor(0.0, device=self.device) + + # Initialize the metrics sum + metrics_sum = {k: torch.tensor(0.0, device=self.device) for k in self.metric_keys} + + num_images_seen = 0 for batch in self.val_loader: # Compute Loss on decoder outputs @@ -194,69 +210,86 @@ def validate_step(self, amg_kwargs=None, max_images=float('inf'), reduce_all_ran if out[0] is None: continue # no instances in this batch prd_masks, prd_scores, gt_masks = out[:3] - batch_n = torch.tensor(float(gt_masks.shape[0]), device=self.device) + local_n = torch.tensor(float(gt_masks.shape[0]), device=self.device) with self.autocast(): - losses = self.loss_fn(prd_masks, prd_scores, gt_masks) - # convert to sum over instances - loss_sum += float(losses["loss_total"].detach().cpu()) * batch_n - n_inst += batch_n - - # Compute metrics on THIS batch only (keeps memory small & parallel) - m = automask_metrics( - self.predictor, # predictor or predictor.model (your function supports either) - batch["images"], # list[H×W×3] or 
list[H×W] - batch["masks"], # list[list[H×W]] - top_k=20, - device=self.device, - autocast_ctx=self.autocast, - amg_kwargs=amg_kwargs, - ) + batch_losses = self.loss_fn(prd_masks, prd_scores, gt_masks) + + if best_metric == 'ABIoU': + m = automask_metrics( + self.predictor, # predictor or predictor.model (your function supports either) + batch["images"], # list[H×W×3] or list[H×W] + batch["masks"], # list[list[H×W]] + top_k=20, + device=self.device, + autocast_ctx=self.autocast, + amg_kwargs=self.amg_kwargs, + ) + else: + m = sam2_metrics(batch, out, amg) + + # means → sums + for k in loss_keys: + # detach→cpu→float to avoid graph + dtype issues + losses_sum[k] += float(batch_losses[k].detach().cpu()) * local_n + n_inst += local_n # Weight by number of images so we can average correctly later - num = float(m["num_images"]) - abiou_sum += torch.tensor(m["ABIoU"] * num, device=self.device) - n_imgs += torch.tensor(num, device=self.device) - num_images += num - if num_images >= max_images: + img_count = float(m["num_images"]) + for k in self.metric_keys: + metrics_sum[k] += torch.tensor(m[k] * img_count, device=self.device) + n_imgs += torch.tensor(img_count, device=self.device) + + num_images_seen += img_count + if num_images_seen >= max_images: break - # Global reduction (sum across all ranks) - if self.use_fabric and reduce_all_ranks: - loss_sum = self.fabric.all_reduce(loss_sum, reduce_op="sum") - abiou_sum = self.fabric.all_reduce(abiou_sum, reduce_op="sum") - n_imgs = self.fabric.all_reduce(n_imgs, reduce_op="sum") - n_inst = self.fabric.all_reduce(n_inst, reduce_op="sum") - else: - abiou_sum = abiou_sum.sum() - n_imgs = n_imgs.sum() + # Reduce losses across ranks + losses_sum = self._all_reduce_sum(losses_sum) + n_inst = self._all_reduce_sum(n_inst) + n_imgs = self._all_reduce_sum(n_imgs) + metrics_sum = self._all_reduce_sum(metrics_sum) # Avoid divide-by-zero - denom = max(n_imgs.item(), 1.0) - loss_denom = max(n_inst.item(), 1.0) - return { - "loss": 
(loss_sum / loss_denom).item(), - "ABIoU": (abiou_sum / denom).item(), - "num_images": int(denom), + img_denom = max(n_imgs.item(), 1.0) + inst_denom = max(n_inst.item(), 1.0) + + out = { + "loss_total": (losses_sum["loss_total"] / inst_denom).item(), + "loss_iou": (losses_sum["loss_iou"] / inst_denom).item(), + "loss_dice": (losses_sum["loss_dice"] / inst_denom).item(), + "loss_mask": (losses_sum["loss_mask"] / inst_denom).item(), + "num_images": int(img_denom), } + out.update({k: (metrics_sum[k] / img_denom).item() for k in self.metric_keys}) + + return out - def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 10): + def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 100): """ Fine Tune SAM2 on the given data. """ # Initialize the loss function self.loss_fn = MultiMaskIoULoss( - weight_dict={"loss_mask": 1.0, "loss_dice": 1.0, "loss_iou": 0.15}, + weight_dict={"loss_mask": 10.0, "loss_dice": 1.0, "loss_iou": 1.0}, focal_alpha=self.focal_alpha, focal_gamma=self.focal_gamma, supervise_all_iou=self.supervise_all_iou, iou_use_l1_loss=self.iou_use_l1_loss ) + # Initialize the metric keys + if best_metric == 'ABIoU': + self.metric_keys = ['ABIoU'] + else: + self.metric_keys = ['prompt_miou', 'cal_mae', 'cal_brier', 'cal_ece', 'cal_tables', + 'AR', 'R@10', 'R@50', 'R@100'] + # Cosine scheduler w/Warmup ---- - warmup_epochs = max(int(0.05 * num_epochs), 1) - self.warmup_sched = LinearLR(self.optimizer, start_factor=1e-3, total_iters=warmup_epochs) + # warmup_epochs = max(int(0.01 * num_epochs), 1) + warmup_epochs = 5 + self.warmup_sched = LinearLR(self.optimizer, start_factor=0.1, total_iters=warmup_epochs) self.cosine_sched = CosineAnnealingLR(self.optimizer, T_max=(num_epochs - warmup_epochs), eta_min=1e-6) self.scheduler = SequentialLR(self.optimizer, [self.warmup_sched, self.cosine_sched], milestones=[warmup_epochs]) @@ -267,25 +300,33 @@ def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 10): else: pbar = 
None - best_metric_value = float('-inf') self.optimizer.zero_grad() + best_metric_value = float('-inf') + # Main Loop for epoch in range(num_epochs): # Train - epoch_loss_train = 0 - self.predictor.model.train() + # at start of each epoch + if hasattr(self.train_loader, "sampler") and hasattr(self.train_loader.sampler, "set_epoch"): + self.train_loader.sampler.set_epoch(epoch) + if (epoch+1) % resample_frequency == 0: self.train_loader.dataset.resample_epoch() + + self.predictor.model.train() for batch in self.train_loader: out = self.forward_step(batch) if out[0] is None: continue prd_masks, prd_scores, gt_masks = out[:3] with self.autocast(): - losses = self.loss_fn(prd_masks, prd_scores, gt_masks) + batch_losses = self.loss_fn(prd_masks, prd_scores, gt_masks) if self.use_fabric: - self.fabric.backward(losses['loss_total']) + self.fabric.backward(batch_losses['loss_total']) else: - losses['loss_total'].backward() + batch_losses['loss_total'].backward() + + # number of instances this rank used to compute its per-batch means + _local_n = gt_masks.shape[0] # (optional) gradient clip: if self.use_fabric: @@ -303,39 +344,85 @@ def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 10): self.optimizer.step() self.optimizer.zero_grad() - epoch_loss_train += float(losses["loss_total"].detach().cpu()) + losses = self._reduce_losses(batch_losses, _local_n) # Learning Rate Scheduler self.scheduler.step() # Validate - if (epoch+1) % 500 == 0: - metrics = self.amg_param_tuner() + metrics = {} + if (epoch+1) % 1e4 == 0: + metrics['val'] = self.amg_param_tuner() else: - metrics = self.validate_step() + metrics['val'] = self.validate_step() + metrics['train'] = losses # Print Only on Rank 0 if self.is_global_zero: - # Print Metrics - metrics['train'] = {'loss': epoch_loss_train/len(self.train_loader)} pbar.set_postfix({ - "train_loss": f"{metrics['train']['loss']:.4f}", - "val_loss": f"{metrics['loss']:.4f}", - "ABIoU": f"{metrics['ABIoU']:.4f}", + "train_loss": 
f"{metrics['train']['loss_total']:.4f}", + "val_loss": f"{metrics['val']['loss_total']:.4f}", + f"val_{best_metric}": f"{metrics['val'][best_metric]:.4f}", }) pbar.update(1) # Save Training Log metrics['epoch'] = epoch - metrics['lr'] = self.scheduler.get_last_lr()[0] - save_training_log(metrics, self.save_dir) + metrics['lr_mask'] = self.scheduler.get_last_lr()[0] + metrics['lr_prompt'] = self.scheduler.get_last_lr()[1] + save_training_log(metrics, self.save_dir, self.metric_keys) # Save Model if best metric is achieved - if metrics[best_metric] > best_metric_value: - best_metric_value = metrics[best_metric] - ckpt = {"model": self.predictor.model.state_dict()} + ckpt = {"model": self.predictor.model.state_dict()} + metric_value = metrics['val'].get(best_metric) + if metric_value > best_metric_value: + best_metric_value = metric_value torch.save(ckpt, f"{self.save_dir}/best_model.pth") print(f"Best {best_metric} saved!") + else: + torch.save(ckpt, f"{self.save_dir}/bad_model.pth") + + def _reduce_losses(self, losses, num_elems: int = None): + """ + Reduce the losses across ranks. 
+ """ + key_map = { + "loss_iou": "loss_iou", + "loss_dice": "loss_dice", + "loss_mask": "loss_mask", + "loss_total": "loss_total", + } + count = torch.tensor(float(num_elems if num_elems is not None else 1.0), device=self.device) + out = {} + if self.use_fabric: + global_count = self.fabric.all_reduce(count, reduce_op="sum") + for long_k, short_k in key_map.items(): + if long_k not in losses: + continue + num = torch.tensor(float(losses[long_k].detach().item()), device=self.device) * count + global_num = self.fabric.all_reduce(num, reduce_op="sum") + out[short_k] = (global_num / torch.clamp(global_count, min=1.0)).item() + else: + for long_k, short_k in key_map.items(): + if long_k in losses: + out[short_k] = float(losses[long_k].detach().item()) + return out + + def _all_reduce_sum(self, x): + if not self.use_fabric: + return x + if isinstance(x, torch.Tensor): + return self.fabric.all_reduce(x, reduce_op="sum") + if isinstance(x, dict): + return {k: self.fabric.all_reduce(v, reduce_op="sum") for k, v in x.items()} + raise TypeError(f"_all_reduce_sum expects Tensor or dict[str,Tensor], got {type(x)}") + + # def _all_reduce_sum(self, x: torch.Tensor) -> torch.Tensor: + # """ + # """ + # return self.fabric.all_reduce(x, reduce_op="sum") if self.use_fabric else x + +############### Experimental - Automatic Mask Generator Tuning ############### def amg_param_tuner(self, n_trials=10): """ From 41d420053bb4740023427d571eb7e7eeaa86617e Mon Sep 17 00:00:00 2001 From: root Date: Tue, 30 Sep 2025 04:12:50 +0000 Subject: [PATCH 11/15] added support for training with multiple prompts --- .../classifier/preprocess/split_merge_data.py | 63 +++- saber/finetune/losses.py | 5 +- saber/finetune/prep.py | 300 ++++++++++++++++++ saber/finetune/trainer.py | 196 +++++++++--- 4 files changed, 501 insertions(+), 63 deletions(-) create mode 100644 saber/finetune/prep.py diff --git a/saber/classifier/preprocess/split_merge_data.py b/saber/classifier/preprocess/split_merge_data.py index 
5090833..d00cd8d 100644 --- a/saber/classifier/preprocess/split_merge_data.py +++ b/saber/classifier/preprocess/split_merge_data.py @@ -1,9 +1,37 @@ from sklearn.model_selection import train_test_split from typing import List, Tuple, Dict, Optional +from zarr.convenience import copy as zarr_copy from pathlib import Path import click, zarr, os import numpy as np +def copy_like(src_arr, dst_group, path: str): + # Ensure parent groups for nested paths like "labels/0" + parent = dst_group + parts = path.split('/') + for p in parts[:-1]: + parent = parent.require_group(p) + + leaf = parts[-1] + # Create (or reuse) the dst array with identical metadata + dst_arr = parent.require_dataset( + leaf, + shape=src_arr.shape, + dtype=src_arr.dtype, + chunks=src_arr.chunks, + compressor=src_arr.compressor, + filters=src_arr.filters, + order=src_arr.order, + fill_value=src_arr.fill_value, + # Optional (if your sources use it): + **({"dimension_separator": getattr(src_arr, "dimension_separator", None)} + if hasattr(src_arr, "dimension_separator") else {}) + ) + dst_arr[:] = src_arr[:] # fast data copy + # Preserve array attrs + if src_arr.attrs: + dst_arr.attrs.update(src_arr.attrs) + def split( input: str, ratio: float, @@ -57,19 +85,25 @@ def split( items = ['0', 'labels/0', 'labels/rejected'] print('Copying data to train zarr file...') for key in train_keys: - train_zarr.create_group(key) # Explicitly create the group first - copy_attributes(zfile[key], train_zarr[key]) + dst_grp = train_zarr.require_group(key) + copy_attributes(zfile[key], dst_grp) for item in items: - train_zarr[key][item] = zfile[key][item][:] # [:] ensures a full copy - copy_attributes(zfile[key]['labels'], train_zarr[key]['labels']) + try: + copy_like(zfile[key][item], dst_grp, item) + copy_attributes(zfile[key][item], dst_grp[item]) + except Exception as e: + pass print('Copying data to validation zarr file...') for key in val_keys: - val_zarr.create_group(key) # Explicitly create the group first - 
copy_attributes(zfile[key], val_zarr[key]) + dst_grp = val_zarr.require_group(key) + copy_attributes(zfile[key], dst_grp) for item in items: - val_zarr[key][item] = zfile[key][item][:] # [:] ensures a full copy - copy_attributes(zfile[key]['labels'], val_zarr[key]['labels']) + try: + copy_like(zfile[key][item], dst_grp, item) + copy_attributes(zfile[key][item], dst_grp[item]) + except Exception as e: + pass # Print summary print(f"\nSplit Summary:") @@ -124,18 +158,16 @@ def merge(inputs: List[str], output: str): write_key = session_label + '_' + key # Create the group and copy its attributes - new_group = mergedZarr.create_group(write_key) # Explicitly create the group first - copy_attributes(zfile[key], new_group) + dst_grp = mergedZarr.require_group(write_key) + copy_attributes(zfile[key], dst_grp) # Copy the data arrays for item in items: try: - # [:] ensures a full copy - mergedZarr[write_key][item] = zfile[key][item][:] + copy_like(zfile[key][item], dst_grp, item) + copy_attributes(zfile[key][item], dst_grp[item]) except Exception as e: pass - # Copy attributes for labels subgroup - copy_attributes(zfile[key]['labels'], new_group['labels']) # Copy all attributes from the last input zarr file for attr_name, attr_value in zfile.attrs.items(): @@ -194,8 +226,5 @@ def copy_attributes(source, destination): """ if hasattr(source, 'attrs') and source.attrs: destination.attrs.update(source.attrs) - -if __name__ == '__main__': - cli() diff --git a/saber/finetune/losses.py b/saber/finetune/losses.py index 0818c69..3e3fe4a 100644 --- a/saber/finetune/losses.py +++ b/saber/finetune/losses.py @@ -52,8 +52,6 @@ def forward(self, prd_masks, prd_scores, gt_masks): # ---- Compute hard predictions and true IoU per proposal (no grad) ---------- with torch.no_grad(): - # probs_k = prd_masks.sigmoid() # [N,K,H,W] - # pred_bin = (probs_k > 0.0).to(gt_masks.dtype) # [N,K,H,W] pred_bin = (prd_masks > 0.0).to(gt_masks.dtype) # [N,K,H,W] gt_k = gt_masks[:, None].expand_as(pred_bin) # 
[N,K,H,W] inter = (pred_bin * gt_k).sum(dim=(2, 3)) # [N,K] @@ -71,7 +69,8 @@ def forward(self, prd_masks, prd_scores, gt_masks): ).view(N, K) # [N,K] seg_loss_per_k = focal_per_k + dice_per_k # [N,K] - best_idx = seg_loss_per_k.argmin(dim=1) # [N] choose lowest seg loss + # best_idx = seg_loss_per_k.argmin(dim=1) # [N] choose lowest seg loss + best_idx = true_iou_k.argmax(dim=1) # [N] choose highest IoU row = torch.arange(N, device=device) logits_star = prd_masks[row, best_idx] # [N,H,W] diff --git a/saber/finetune/prep.py b/saber/finetune/prep.py new file mode 100644 index 0000000..ae74945 --- /dev/null +++ b/saber/finetune/prep.py @@ -0,0 +1,300 @@ +""" +Convert copick data to 3D zarr format with simple JSON segmentation mapping. +Saves full 3D volumes and creates a simple ID-to-organelle mapping sorted by volume. +""" + +from saber.utils.zarr_writer import add_attributes +from typing import Dict, List, Optional, Tuple +from copick_utils.io import readers +from skimage.measure import label +import copick, json, click +from tqdm import tqdm +import numpy as np +import zarr + + +def process_run_3d_simple( + run, + voxel_spacing: float, + tomo_alg: str, + organelle_names: List[str], + min_component_volume: int = 100, + user_id: Optional[str] = None +) -> Tuple[np.ndarray, np.ndarray, Dict[str, str]]: + """ + Process a single run in full 3D with simple volume-based sorting. 
+ + Returns: + volume_3d: Full 3D tomogram + seg_3d: 3D segmentation with unique labels sorted by volume + id_to_organelle: Dictionary mapping mask indices to organelle names + """ + + click.echo(" Loading tomogram...") + + # Get tomogram data + vs = run.get_voxel_spacing(voxel_spacing) + tomograms = vs.get_tomograms(tomo_alg) + + if not tomograms: + raise ValueError(f"No tomograms found for run {run.name}") + + volume_3d = tomograms[0].numpy() + click.echo(f" Tomogram shape: {volume_3d.shape}") + + components = [] + offset = 1 + temp_seg_3d = np.zeros_like(volume_3d, dtype=np.uint32) + + for organelle_name in organelle_names: + try: + click.echo(f" Processing {organelle_name}...") + seg = readers.segmentation(run, voxel_spacing, organelle_name, user_id=user_id) + + if seg.shape != volume_3d.shape: + temp_seg_3d = np.zeros_like(seg, dtype=np.uint32) + + if seg is not None: + # Convert to binary and separate connected components + binary_mask = (seg > 0.5).astype(np.uint8) + labeled_mask = label(binary_mask, connectivity=3) + + # Most efficient: process in a single pass + unique_labels, counts = np.unique(labeled_mask[labeled_mask > 0], return_counts=True) + + for label_val, vol in zip(unique_labels, counts): + if vol >= min_component_volume: + temp_seg_3d[labeled_mask == label_val] = offset + components.append((organelle_name, offset, vol)) + offset += 1 + else: + click.echo(" No segmentation found") + except Exception as e: + click.echo(f" Error processing {organelle_name}: {e}") + + # Sort by volume (smallest first, so small objects are on top in GUI) + components.sort(key=lambda x: x[2]) + + # Create final segmentation with remapped labels and the mapping dictionary + seg_3d = np.zeros_like(temp_seg_3d, dtype=np.uint16) + id_to_organelle: Dict[str, str] = {} + + for new_label, (organelle_name, old_label, _volume) in enumerate(components, start=1): + seg_3d[temp_seg_3d == old_label] = new_label + id_to_organelle[str(new_label)] = organelle_name + + return 
volume_3d, seg_3d, id_to_organelle + + +def convert_copick_to_3d_zarr( + config_path: str, + output_zarr_path: str, + output_json_path: Optional[str], + voxel_spacing: float, + tomo_alg: str, + specific_runs: Optional[List[str]], + min_component_volume: int, + compress: bool, + user_id: Optional[str], +): + """ + Convert copick data to 3D zarr format with JSON segmentation mapping. + """ + + # Initialize copick + root = copick.from_file(config_path) + + # Get organelle names (non-particle objects) + organelle_names = [obj.name for obj in root.pickable_objects if not obj.is_particle] + + # Optional: filter out specific organelles + organelle_names = [x for x in organelle_names if "membrane" not in x] + click.echo(f"Found organelles: {organelle_names}") + + # Set default JSON output path + if output_json_path is None: + output_json_path = output_zarr_path.replace(".zarr", "_mapping.json") + + # Initialize zarr store + compressor = zarr.Blosc(cname="zstd", clevel=2) if compress else None + store = zarr.DirectoryStore(output_zarr_path) + zroot = zarr.group(store=store, overwrite=True) + + # Master mapping dictionary + master_mapping: Dict[str, Dict[str, str]] = {} + + # Determine which runs to process + runs_to_process = specific_runs if specific_runs else [run.name for run in root.runs] + + for run_name in tqdm(runs_to_process, desc="Processing runs"): + click.echo(f"\nProcessing run: {run_name}") + run = root.get_run(run_name) + + try: + # Process the 3D data + volume_3d, seg_3d, id_to_organelle = process_run_3d_simple( + run=run, + voxel_spacing=voxel_spacing, + tomo_alg=tomo_alg, + organelle_names=organelle_names, + min_component_volume=min_component_volume, + user_id=user_id, + ) + + # Create zarr group for this run + run_group = zroot.create_group(run_name) + + # Save 3D volume + voxel_spacing_nm = voxel_spacing / 10 # Convert to nm + run_group.create_dataset( + "0", + data=volume_3d, + chunks=(1, volume_3d.shape[1], volume_3d.shape[2]), # Chunk by slices + 
compressor=compressor, + dtype=volume_3d.dtype, + ) + add_attributes(run_group, voxel_spacing_nm, True, voxel_spacing_nm) + + # Save 3D segmentation/label volume + label_group = run_group.create_group("labels") + label_group.create_dataset( + "0", + data=seg_3d, + chunks=(1, volume_3d.shape[1], volume_3d.shape[2]), + compressor=compressor, + dtype=seg_3d.dtype, + ) + add_attributes(label_group, voxel_spacing_nm, True, voxel_spacing_nm) + + # Add to master mapping + master_mapping[run_name] = id_to_organelle + + click.echo(f" ✅ Saved {run_name} with {len(id_to_organelle)} segmentations") + + except Exception as e: + click.echo(f" ❌ Error processing {run_name}: {e}") + continue + + # Save master JSON mapping + with open(output_json_path, "w") as f: + json.dump(master_mapping, f, indent=2) + + click.echo("\n🎉 Conversion complete!") + click.echo(f"📁 Zarr output: {output_zarr_path}") + click.echo(f"📄 JSON mapping: {output_json_path}") + click.echo(f"📊 Total runs processed: {len(master_mapping)}") + + +def load_3d_zarr_data(zarr_path: str, run_name: str) -> Tuple[np.ndarray, np.ndarray, Dict]: + """ + Load 3D data from zarr for a specific run. + + Returns: + volume: 3D tomogram + labels: 3D label volume + id_to_organelle: Mapping of label values to organelle names + """ + store = zarr.DirectoryStore(zarr_path) + zroot = zarr.group(store=store, mode="r") + + if run_name not in zroot: + raise ValueError(f"Run {run_name} not found in zarr") + + run_group = zroot[run_name] + + # NOTE: This mirrors the writer above: datasets saved under '0' and 'labels/0'. + volume = run_group["0"][:] + labels = run_group["labels"]["0"][:] + # If you want to persist per-run mapping inside zarr attrs instead of the external JSON, + # you can set and read it here. Currently, mappings are stored in the external JSON. 
+ id_to_organelle = {} + + return volume, labels, id_to_organelle + + +@click.command(context_settings={"show_default": True}) +@click.option( + "--config", + "config_path", + type=click.Path(exists=True, dir_okay=False, readable=True, path_type=str), + default="config.json", + help="Path to copick config file.", +) +@click.option( + "--output-zarr", + "output_zarr_path", + type=click.Path(dir_okay=True, writable=True, path_type=str), + required=True, + help="Output path for the zarr directory.", +) +@click.option( + "--output-json", + "output_json_path", + type=click.Path(dir_okay=False, writable=True, path_type=str), + default=None, + help="Output path for JSON mapping (defaults to _mapping.json).", +) +@click.option( + "--voxel-size", + "voxel_spacing", + type=float, + default=7.84, + help="Voxel spacing for the tomogram data (Å).", +) +@click.option( + "--tomo-alg", + type=str, + default="wbp-denoised-ctfdeconv", + help="Tomogram algorithm to use for processing.", +) +@click.option( + "--specific-run", + "specific_runs", + multiple=True, + help="Process only specific runs. 
Repeat this option for multiple runs.", +) +@click.option( + "--min-component-volume", + type=int, + default=10000, + help="Minimum connected-component volume (in voxels).", +) +@click.option( + "--user-id", + type=str, + default=None, + help="UserID for accessing segmentation.", +) +@click.option( + "--no-compress", + is_flag=True, + default=False, + help="Disable compression for zarr storage.", +) +def main( + config_path: str, + output_zarr_path: str, + output_json_path: Optional[str], + voxel_spacing: float, + tomo_alg: str, + specific_runs: Optional[List[str]], + min_component_volume: int, + user_id: Optional[str], + no_compress: bool, +): + """Convert copick data to 3D zarr format with JSON segmentation mapping.""" + convert_copick_to_3d_zarr( + config_path=config_path, + output_zarr_path=output_zarr_path, + output_json_path=output_json_path, + voxel_spacing=voxel_spacing, + tomo_alg=tomo_alg, + specific_runs=list(specific_runs) if specific_runs else None, + min_component_volume=min_component_volume, + compress=not no_compress, + user_id=user_id, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/saber/finetune/trainer.py b/saber/finetune/trainer.py index e5e019a..f63710a 100644 --- a/saber/finetune/trainer.py +++ b/saber/finetune/trainer.py @@ -4,12 +4,13 @@ from saber.finetune.abiou import automask_metrics from saber.finetune.losses import MultiMaskIoULoss from saber.finetune.metrics import sam2_metrics +import torch, os, optuna, random +import torch.nn.functional as F from lightning import fabric -import torch, os, optuna from tqdm import tqdm class SAM2FinetuneTrainer: - def __init__(self, predictor, train_loader, val_loader): + def __init__(self, predictor, train_loader, val_loader, seed=42): # Store the predictor self.predictor = predictor @@ -63,13 +64,14 @@ def _autocast(): crop_n_layers=0, crop_n_points_downscale_factor=2, box_nms_thresh=0.6, - use_m2m=False, - multimask_output=True, + use_m2m=True, + 
multimask_output=False, ) self.nAMGtrials = 10 # Initialize the use_boxes flag - self.use_boxes = False + self.use_boxes = True + self._rng = random.Random(seed) # Initialize the save directory self.save_dir = 'results' @@ -98,18 +100,102 @@ def _stack_image_embeddings_from_predictor(self): hr_feats = [torch.stack([lvl[b] for b in range(B)], dim=0).to(self.device) for lvl in hr] return image_embeds, hr_feats + def _determine_sampling(self, N, p_points=0.5, p_box=0.15, p_mask=0.2, p_mask_box=0.15): + """ + Decide which prompt combo each instance uses. + Returns a list[int] of length N with codes: + 0 = points only + 1 = box + points + 2 = mask + points + 3 = mask + box + points + """ + # normalize to avoid drift if probs don't sum to 1 exactly + probs = [p_points, p_box, p_mask, p_mask_box] + s = sum(probs); probs = [p / s for p in probs] + # cumulative edges for a single uniform draw + e0 = probs[0] + e1 = e0 + probs[1] + e2 = e1 + probs[2] + + combo = [] + for _ in range(N): + r = self._rng.random() + if r < e0: combo.append(0) + elif r < e1: combo.append(1) + elif r < e2: combo.append(2) + else: combo.append(3) + return combo + + def _process_inputs(self, N, mask_logits_full, pts_all, lbl_all, boxes_full, combo): + """ + Build per-instance prompts to feed _prep_prompts(): + - trim points when also using box/mask (keep 1–3 anchors) + - pad clicks to (N, P, 2) and (N, P) with labels=-1 for ignored slots + - select boxes/mask_logits per instance based on combo + """ + device = self.device + + # Which instances use which prompts + use_boxes = torch.tensor([c in (1, 3) for c in combo], device=device) + use_masks = torch.tensor([c in (2, 3) for c in combo], device=device) + + # ---- Trim clicks (when box/mask present we keep a few anchors to avoid over-conditioning) + pts_trim, lbl_trim = [], [] + for i, (p, l) in enumerate(zip(pts_all, lbl_all)): + if combo[i] in (1, 2, 3) and p.shape[0] > 3: + pts_trim.append(p[:3]) + lbl_trim.append(l[:3]) + else: + 
pts_trim.append(p) + lbl_trim.append(l) + + # ---- Pad to dense tensors; labels=-1 means "ignore" for _prep_prompts + max_p = max((p.shape[0] for p in pts_trim), default=0) + pts_pad = torch.zeros((N, max_p, 2), device=device, dtype=torch.float32) + lbl_pad = torch.full((N, max_p), -1.0, device=device, dtype=torch.float32) + for i, (p, l) in enumerate(zip(pts_trim, lbl_trim)): + if p.numel(): + pts_pad[i, :p.shape[0]] = p.to(device, dtype=torch.float32) + lbl_pad[i, :l.shape[0]] = l.to(device, dtype=torch.float32) + + # ---- Ensure boxes_full exists & is float32; supply dummy box when unused + if boxes_full is None: + boxes_full = torch.tensor([[0, 0, 1, 1]], device=device, dtype=torch.float32).expand(N, 4) + else: + boxes_full = boxes_full.to(device, dtype=torch.float32) + + boxes_sel = torch.where( + use_boxes[:, None], + boxes_full, + torch.tensor([0, 0, 1, 1], device=device, dtype=torch.float32).expand_as(boxes_full) + ) + + # ---- Gate mask logits per instance (mask prompt when requested; zeros otherwise) + mask_logits_sel = torch.where( + use_masks[:, None, None], + mask_logits_full.to(device, dtype=torch.float32), + torch.zeros_like(mask_logits_full, device=device, dtype=torch.float32) + ) + + return pts_pad, lbl_pad, boxes_sel, mask_logits_sel + + def forward_step(self, batch): """ - Returns: prd_masks [N,K,H,W] logits, prd_scores [N,K], gt_masks [N,H,W], inst_img_ix [N] + Returns: + prd_masks: [N, K, H, W] (logits at image res) + prd_scores: [N, K] (predicted IoU/head) + gt_masks: [N, H, W] + inst_img_ix: [N] (which original image each instance came from) """ - images = batch["images"] # list of HxWx3 uint8 or float; predictor handles them + images = batch["images"] # list of B images (HxWx3); predictor handles types B = len(images) - # 1) Encode images once - self.predictor.set_image_batch(images) # caches features on predictor + # 1) encode once + self.predictor.set_image_batch(images) image_embeds_B, hr_feats_B = 
self._stack_image_embeddings_from_predictor() - # 2) Flatten instances across batch, move tensors to device + # 2) flatten instances inst_img_ix, gt_all, pts_all, lbl_all, box_all = [], [], [], [], [] for b in range(B): for m, p, l, bx in zip(batch["masks"][b], batch["points"][b], batch["labels"][b], batch["boxes"][b]): @@ -124,34 +210,59 @@ def forward_step(self, batch): return None, None, None, None inst_img_ix = torch.tensor(inst_img_ix, device=self.device, dtype=torch.long) - # 3) Pad clicks to (N,P,2) and (N,P) - P = max(p.shape[0] for p in pts_all) - pts_pad = torch.zeros((N, P, 2), device=self.device) - lbl_pad = torch.full((N, P), -1.0, device=self.device) # <- ignore - for i,(p,l) in enumerate(zip(pts_all, lbl_all)): - pts_pad[i, :p.shape[0]] = p - lbl_pad[i, :l.shape[0]] = l - - # Optional boxes - boxes = torch.stack(box_all, dim=0) if (self.use_boxes and len(box_all) > 0) else None - - # 4) Prompt encoding - mask_input, unnorm_coords, labels, _ = self.predictor._prep_prompts( - pts_pad, lbl_pad, - box=boxes, mask_logits=None, + # 3) prompt combos + combo = self._determine_sampling(N) + + # 4) boxes + boxes_full = torch.stack(box_all, dim=0).to(self.device, dtype=torch.float32) if len(box_all) > 0 else None + if boxes_full is None: + boxes_full = torch.tensor([[0, 0, 1, 1]], device=self.device, dtype=torch.float32).expand(N, 4) + + # 5) mask logits (+/-6) + gt_masks_bin = torch.stack([m.to(torch.float32) for m in gt_all], dim=0).to(self.device) + mask_logits_full = (gt_masks_bin * 2.0 - 1.0) * 6.0 + + # 6) build per-instance prompts + pts_pad, lbl_pad, boxes, mask_logits = self._process_inputs( + N, mask_logits_full, pts_all, lbl_all, boxes_full, combo + ) + has_any_mask = (mask_logits is not None) and (mask_logits.abs().sum() > 0) + + # 7) prep prompts (prompt-space outputs) + mask_input, point_coords, point_labels, boxes_input = self.predictor._prep_prompts( + pts_pad, lbl_pad, + box=boxes, + mask_logits=(mask_logits if has_any_mask else None), 
normalize_coords=True ) + + # --- shape fix + spatial size for dense mask prompt --- + Hf, Wf = image_embeds_B.shape[-2], image_embeds_B.shape[-1] + target_mask_h, target_mask_w = Hf * 4, Wf * 4 + + if mask_input is not None: + mask_input = mask_input.to(self.device, dtype=torch.float32) + if mask_input.dim() == 3: + mask_input = mask_input.unsqueeze(1) # [N,1,H,W] + elif mask_input.dim() == 4 and mask_input.shape[0] == 1 and mask_input.shape[1] > 1: + mask_input = mask_input.permute(1, 0, 2, 3).contiguous() # [N,1,H,W] + if mask_input.shape[1] != 1: + mask_input = mask_input[:, :1] + if mask_input.shape[-2:] != (target_mask_h, target_mask_w): + mask_input = F.interpolate(mask_input, (target_mask_h, target_mask_w), mode="bilinear", align_corners=False) + + # 8) encode prompts (use prompt-space tensors) sparse_embeddings, dense_embeddings = self.predictor.model.sam_prompt_encoder( - points=(unnorm_coords, labels), - boxes=boxes if self.use_boxes else None, - masks=None, + points=(point_coords, point_labels), + boxes=boxes_input, + masks=mask_input, ) - # 5) Gather per-instance image feats - image_embeds = image_embeds_B[inst_img_ix] # [N,C,H',W'] - hr_feats = [lvl[inst_img_ix] for lvl in hr_feats_B] # list of [N,C,H',W'] + # 9) gather image feats per instance + image_embeds = image_embeds_B[inst_img_ix] + hr_feats = [lvl[inst_img_ix] for lvl in hr_feats_B] - # 6) Decode + # 10) decode low_res_masks, prd_scores, _, _ = self.predictor.model.sam_mask_decoder( image_embeddings=image_embeds, image_pe=self.predictor.model.sam_prompt_encoder.get_dense_pe(), @@ -162,19 +273,17 @@ def forward_step(self, batch): high_res_features=hr_feats, ) - # 7) Upscale + stack GT - target_sizes = [self.predictor._orig_hw[int(b)] for b in inst_img_ix] # list of (H, W), len = N - # postprocess per instance (supports single size); do it in a loop then stack + # 11) upsample to image res + target_sizes = [self.predictor._orig_hw[int(b)] for b in inst_img_ix] upsampled = [] - for i in 
range(low_res_masks.shape[0]): # N + for i in range(low_res_masks.shape[0]): H, W = target_sizes[i] - up_i = self.predictor._transforms.postprocess_masks( - low_res_masks[i:i+1], (H, W) - ) # [1,K,H,W] + up_i = self.predictor._transforms.postprocess_masks(low_res_masks[i:i+1], (H, W)) upsampled.append(up_i) + prd_masks = torch.cat(upsampled, dim=0) - prd_masks = torch.cat(upsampled, dim=0) # [N,K,H,W] - gt_masks = torch.stack(gt_all, dim=0).float() # [N,H,W] + # 12) stack GT + gt_masks = torch.stack(gt_all, dim=0).float() return prd_masks, prd_scores, gt_masks, inst_img_ix @@ -283,8 +392,9 @@ def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 100): if best_metric == 'ABIoU': self.metric_keys = ['ABIoU'] else: - self.metric_keys = ['prompt_miou', 'cal_mae', 'cal_brier', 'cal_ece', 'cal_tables', - 'AR', 'R@10', 'R@50', 'R@100'] + self.metric_keys = [ + 'prompt_miou', 'cal_mae', 'cal_brier', 'cal_ece', + 'cal_tables', 'AR', 'R@10', 'R@50', 'R@100'] # Cosine scheduler w/Warmup ---- # warmup_epochs = max(int(0.01 * num_epochs), 1) From 352591ac31dc8d2b776a44023307ae8b32394ada Mon Sep 17 00:00:00 2001 From: root Date: Tue, 30 Sep 2025 23:15:31 +0000 Subject: [PATCH 12/15] make sure slices are biased for center of tomogram --- saber/finetune/dataset.py | 173 +++++++++++++++++++++++++++----------- saber/finetune/metrics.py | 5 +- saber/finetune/train.py | 5 +- saber/finetune/trainer.py | 10 ++- saber/main.py | 1 - 5 files changed, 139 insertions(+), 55 deletions(-) diff --git a/saber/finetune/dataset.py b/saber/finetune/dataset.py index f168cc9..b057d00 100644 --- a/saber/finetune/dataset.py +++ b/saber/finetune/dataset.py @@ -1,5 +1,6 @@ from scipy.ndimage import binary_erosion, binary_dilation from monai.transforms import Compose, EnsureChannelFirstd +from typing import Optional, Tuple, Union import saber.finetune.helper as helper from saber.utils import preprocessing from torch.utils.data import Dataset @@ -15,7 +16,8 @@ def __init__(self, 
slabs_per_volume_per_epoch: int = 10, slices_per_fib_per_epoch: int = 5, slab_thickness: int = 5, - seed: int = 42): + seed: int = 42, + shuffle: bool = True): """ Args: tomogram_zarr_path: Path to the tomogram zarr store @@ -38,6 +40,7 @@ def __init__(self, self.k_max = 100 self.transform = transform self.keep_fraction = 0.5 + self.shuffle = shuffle # Check if both data types are available if tomogram_zarr_path is None and fib_zarr_path is None: @@ -83,7 +86,6 @@ def __init__(self, self.seed = seed self._rng = np.random.RandomState(seed) - # Verbose Flag self.verbose = False @@ -112,6 +114,80 @@ def _estimate_zrange(self, key, band=(0.3, 0.7), threshold=0): max_val = vals.max() - self.slab_thickness + min_offset min_val = vals.min() + self.slab_thickness + min_offset return int(min_val), int(max_val) + + def _compute_indices( + self, + D: int, + N: int, + z_bounds: Optional[Tuple[Union[int, float], Union[int, float]]] = None, + ) -> np.ndarray: + """ + Precompute N slice indices for a volume of depth D. + + Args: + D: total #slices (depth). + N: total indices to return. + center_bias: fraction sampled from a Gaussian; the rest uniform. + sigma: std-dev for the Gaussian as a fraction of the (bounded) depth. + z_bounds: optional (z_min, z_max). If ints, treated as absolute slice + indices (inclusive). If floats in [0,1], treated as fractions + of D (e.g., (0.1, 0.9)). Indices are clamped into [0, D-1]. + + Returns: + np.ndarray of shape (N,) with integer slice indices in ascending order. 
+ """ + assert D > 0 and N > 0 + + center_bias = 0.9 + sigma = 0.2 + + # --- Resolve bounds --- + if z_bounds is None: + z0, z1 = 0, D - 1 + else: + a, b = z_bounds + if isinstance(a, float) or isinstance(b, float): + # treat as fractions in [0,1] + a = int(np.floor(np.clip(a, 0.0, 1.0) * (D - 1))) + b = int(np.ceil (np.clip(b, 0.0, 1.0) * (D - 1))) + z0 = int(max(0, min(a, b))) + z1 = int(min(D - 1, max(a, b))) + if z1 < z0: + # empty window → fall back to full range + z0, z1 = 0, D - 1 + + # Window length (inclusive) + W = (z1 - z0 + 1) + if W <= 0: + z0, z1 = 0, D - 1 + W = D + + # --- How many from each sampler --- + n_g = int(round(center_bias * N)) + n_u = N - n_g + + # --- Gaussian over the bounded window --- + mu = z0 + 0.5 * (W - 1) + s = sigma * (W - 1) + if s <= 0: + # degenerate window or sigma=0 → just center + g = np.full(n_g, int(round(mu)), dtype=int) + else: + g = self._rng.normal(mu, s, size=n_g) + g = np.clip(np.round(g).astype(int), z0, z1) + + # --- Uniform over the bounded window --- + u = self._rng.randint(z0, z1 + 1, size=n_u, dtype=int) + + # --- Merge, unique, and pad if needed --- + idx = np.unique(np.concatenate([g, u])) + # If de-dup removed too many, top up with uniform draws + while idx.size < N: + extra = self._rng.randint(z0, z1 + 1, size=N - idx.size, dtype=int) + idx = np.unique(np.concatenate([idx, extra])) + + # Return exactly N indices (sorted for reproducibility) + return np.sort(idx[:N]) def resample_epoch(self): """ Generate new random samples for this epoch """ @@ -122,21 +198,22 @@ def resample_epoch(self): new_tomo_samples = [] for vol_idx in range(self.n_tomogram_volumes): - # Sample random z positions from this tomogram volume + # Sample Indices in [0, D-1], center-biased within the *window* z_min, z_max = self.tomo_shapes[vol_idx] - z_positions = self._rng.randint( - z_min, - z_max, - size=self.slabs_per_volume_per_epoch + z_rel = self._compute_indices( + z_max - z_min, + self.slabs_per_volume_per_epoch, + 
z_bounds=(z_min, z_max) ) - # Add to samples + + # Shift to absolute z and add to samples + z_positions = z_rel + z_min for z_pos in z_positions: new_tomo_samples.append((vol_idx, z_pos)) - - self.tomogram_samples = self._update_samples(self.tomogram_samples, new_tomo_samples) + self.tomogram_samples = new_tomo_samples # Shuffle samples - self._rng.shuffle(self.tomogram_samples) + if self.shuffle: self._rng.shuffle(self.tomogram_samples) # Sample random slices from each FIB volume if self.has_fib: @@ -145,51 +222,20 @@ def resample_epoch(self): for fib_idx in range(self.n_fib_volumes): fib_shape = self.fib_shapes[fib_idx] # Sample random z positions from this FIB volume - z_positions = self._rng.randint( - 0, - fib_shape[0], - size=self.slices_per_fib_per_epoch + z_positions = self._compute_indices( + fib_shape[0], + self.slices_per_fib_per_epoch, + z_bounds=(0, fib_shape[0]) ) for z_pos in z_positions: new_fib_samples.append((fib_idx, z_pos)) self.fib_samples = self._update_samples(self.fib_samples, new_fib_samples) - self._rng.shuffle(self.fib_samples) # Shuffle samples + if self.shuffle: self._rng.shuffle(self.fib_samples) # Shuffle samples # Set epoch length self.epoch_length = len(self.tomogram_samples) + len(self.fib_samples) - - def _update_samples(self, old, new): - """ - Return a mixed list with size == len(new): - - keep = min(round(len(new)*keep_fraction), len(old)) from 'old' - - add = len(new) - keep from 'new' - """ - target = len(new) - if target == 0: - return [] - - # choose keep set from *old* list, new set from *new* list - keep = min(int(round(target * self.keep_fraction)), len(old)) - add = target - keep - - # keep set from *old* list - if keep > 0: - keep_idx = self._rng.choice(len(old), size=keep, replace=False) - kept = [old[i] for i in keep_idx] - else: - kept = [] - - # add set from *new* list - if add > 0: - new_idx = self._rng.choice(len(new), size=add, replace=False) - added = [new[i] for i in new_idx] - else: - added = [] - - # return 
mixed list - return kept + added def __len__(self): return self.epoch_length @@ -404,4 +450,35 @@ def _package_image_item(self, "points": points_t, # list[#p x 2] float32 (xy) "labels": labels_t, # list[#p] all ones "boxes": boxes_t, # list[4] (x0,y0,x1,y1) - } \ No newline at end of file + } + + # def _update_samples(self, old, new): + # """ + # Return a mixed list with size == len(new): + # - keep = min(round(len(new)*keep_fraction), len(old)) from 'old' + # - add = len(new) - keep from 'new' + # """ + # target = len(new) + # if target == 0: + # return [] + + # # choose keep set from *old* list, new set from *new* list + # keep = min(int(round(target * self.keep_fraction)), len(old)) + # add = target - keep + + # # keep set from *old* list + # if keep > 0: + # keep_idx = self._rng.choice(len(old), size=keep, replace=False) + # kept = [old[i] for i in keep_idx] + # else: + # kept = [] + + # # add set from *new* list + # if add > 0: + # new_idx = self._rng.choice(len(new), size=add, replace=False) + # added = [new[i] for i in new_idx] + # else: + # added = [] + + # # return mixed list + # return kept + added \ No newline at end of file diff --git a/saber/finetune/metrics.py b/saber/finetune/metrics.py index 9752551..2a3d906 100644 --- a/saber/finetune/metrics.py +++ b/saber/finetune/metrics.py @@ -230,7 +230,7 @@ def sam2_metrics(batch: Dict[str, Any], outputs: Tuple[torch.Tensor, torch.Tenso all_gt = batch["masks"] ar = average_recall_amg(all_amg, all_gt, iou_thresholds=AR_THRESHOLDS, max_proposals=200) - # rK = recall_at_k_amg(all_amg, all_gt, ks=(10, 50), iou_thresh=0.5) + rK = recall_at_k_amg(all_amg, all_gt, ks=(10, 50, 100), iou_thresh=0.5) return { "prompt_miou": float(pm) if pm == pm else float("nan"), @@ -240,6 +240,9 @@ def sam2_metrics(batch: Dict[str, Any], outputs: Tuple[torch.Tensor, torch.Tenso "cal_tables": cal["table"], # keep for inspection; don’t reduce across ranks "AR": ar["AR"], "num_images": len(batch["images"]), + "R@10": rK["Recall@K"][10], 
+ "R@50": rK["Recall@K"][50], + "R@100": rK["Recall@K"][100], } # # "per_tau_recall": ar["per_tau_recall"], # keep for plotting; don’t reduce as a scalar # "R@10": rK["Recall@K"][10], diff --git a/saber/finetune/train.py b/saber/finetune/train.py index 82c9357..2c9746b 100644 --- a/saber/finetune/train.py +++ b/saber/finetune/train.py @@ -34,18 +34,19 @@ def finetune_sam2( # Load data loaders train_loader = DataLoader( AutoMaskDataset( tomo_train, fib_train, transform=get_finetune_transforms(), - slabs_per_volume_per_epoch=10 ), + slabs_per_volume_per_epoch=20 ), batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) val_loader = DataLoader( AutoMaskDataset( - tomo_val, fib_val, slabs_per_volume_per_epoch=10 ), + tomo_val, fib_val, slabs_per_volume_per_epoch=15 ), num_workers=4, pin_memory=True, collate_fn=collate_autoseg, batch_size=batch_size, shuffle=False ) if (tomo_val or fib_val) else train_loader # Initialize trainer and train trainer = SAM2FinetuneTrainer( predictor, train_loader, val_loader ) + # trainer.train( num_epochs, best_metric='AR' ) trainer.train( num_epochs ) @click.command() diff --git a/saber/finetune/trainer.py b/saber/finetune/trainer.py index f63710a..1448b45 100644 --- a/saber/finetune/trainer.py +++ b/saber/finetune/trainer.py @@ -372,9 +372,13 @@ def validate_step(self, max_images=float('inf'), best_metric='ABIoU'): } out.update({k: (metrics_sum[k] / img_denom).item() for k in self.metric_keys}) + if 'cal_tables' in m and self.is_global_zero: + # (optional) keep the last batch's cal_tables on rank0 for inspection: + out["cal_tables"] = m.get("cal_tables", None) + return out - def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 100): + def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 1e4): """ Fine Tune SAM2 on the given data. 
""" @@ -394,7 +398,7 @@ def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 100): else: self.metric_keys = [ 'prompt_miou', 'cal_mae', 'cal_brier', 'cal_ece', - 'cal_tables', 'AR', 'R@10', 'R@50', 'R@100'] + 'AR', 'R@10', 'R@50', 'R@100'] # Cosine scheduler w/Warmup ---- # warmup_epochs = max(int(0.01 * num_epochs), 1) @@ -464,7 +468,7 @@ def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 100): if (epoch+1) % 1e4 == 0: metrics['val'] = self.amg_param_tuner() else: - metrics['val'] = self.validate_step() + metrics['val'] = self.validate_step(best_metric=best_metric) metrics['train'] = losses # Print Only on Rank 0 diff --git a/saber/main.py b/saber/main.py index cb700bd..68c7601 100644 --- a/saber/main.py +++ b/saber/main.py @@ -9,7 +9,6 @@ from saber.gui.base.zarr_gui import gui gui_available = True except Exception as e: - print(f"GUI is not available: {e}") gui_available = False @click.group() From d0ae3d260bb157bccc57ca876eb924023da11dc9 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 1 Oct 2025 18:02:30 +0000 Subject: [PATCH 13/15] checkpoint --- saber/classifier/datasets/augment.py | 1 + saber/finetune/trainer.py | 20 +++++++++++--------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/saber/classifier/datasets/augment.py b/saber/classifier/datasets/augment.py index c339c20..da13603 100755 --- a/saber/classifier/datasets/augment.py +++ b/saber/classifier/datasets/augment.py @@ -53,6 +53,7 @@ def get_finetune_transforms(target_size=(1024,1024)): RandomOrder([ RandRotate90d(keys=["image", "mask"], prob=0.5, spatial_axes=[0, 1]), RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=0), + RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=1), RandScaleIntensityd(keys="image", prob=0.5, factors=(0.85, 1.15)), RandShiftIntensityd(keys="image", prob=0.5, offsets=(-0.15, 0.15)), RandAdjustContrastd(keys="image", prob=0.5, gamma=(0.85, 1.15)), diff --git a/saber/finetune/trainer.py 
b/saber/finetune/trainer.py index 1448b45..2b060e8 100644 --- a/saber/finetune/trainer.py +++ b/saber/finetune/trainer.py @@ -17,15 +17,17 @@ def __init__(self, predictor, train_loader, val_loader, seed=42): # Two parameter groups for different LRs (optional) params = [ - {"params": [p for p in self.predictor.model.sam_mask_decoder.parameters() if p.requires_grad], - "lr": 1e-4}, - {"params": [p for p in self.predictor.model.sam_prompt_encoder.parameters() if p.requires_grad], - "lr": 5e-5}, + # {"params": [p for p in self.predictor.model.image_encoder.parameters() + # if p.requires_grad], "lr": 6e-5}, + {"params": [p for p in self.predictor.model.sam_mask_decoder.parameters() + if p.requires_grad], "lr": 1e-4}, # 1e-4 + {"params": [p for p in self.predictor.model.sam_prompt_encoder.parameters() + if p.requires_grad], "lr": 5e-5}, # 5e-5 ] # Initialize the optimizer and dataloaders self.num_gpus = torch.cuda.device_count() - optimizer = torch.optim.AdamW(params, weight_decay=1e-5) + optimizer = torch.optim.AdamW(params, weight_decay=1e-3) if self.num_gpus > 1: self.fabric = fabric.Fabric(accelerator="cuda", strategy="ddp", devices=self.num_gpus) self.fabric.launch() @@ -48,7 +50,7 @@ def _autocast(): self.train_loader, self.val_loader = train_loader, val_loader # Initialize the loss function - self.focal_alpha = 0.25 + self.focal_alpha = 0.5 self.focal_gamma = 2.0 self.supervise_all_iou = True self.iou_use_l1_loss = True @@ -60,7 +62,7 @@ def _autocast(): points_per_batch=128, pred_iou_thresh=0.7, stability_score_thresh=0.7, - stability_score_offset=0.0, + stability_score_offset=0.5, crop_n_layers=0, crop_n_points_downscale_factor=2, box_nms_thresh=0.6, @@ -329,7 +331,7 @@ def validate_step(self, max_images=float('inf'), best_metric='ABIoU'): self.predictor, # predictor or predictor.model (your function supports either) batch["images"], # list[H×W×3] or list[H×W] batch["masks"], # list[list[H×W]] - top_k=20, + top_k=25, device=self.device, 
autocast_ctx=self.autocast, amg_kwargs=self.amg_kwargs, @@ -385,7 +387,7 @@ def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 1e4): # Initialize the loss function self.loss_fn = MultiMaskIoULoss( - weight_dict={"loss_mask": 10.0, "loss_dice": 1.0, "loss_iou": 1.0}, + weight_dict={"loss_mask": 20.0, "loss_dice": 1.0, "loss_iou": 1.0}, focal_alpha=self.focal_alpha, focal_gamma=self.focal_gamma, supervise_all_iou=self.supervise_all_iou, From a5e05419a84948c5d784cad5f02ba803c04630b1 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 12 Oct 2025 18:46:17 +0000 Subject: [PATCH 14/15] continue working on drafts --- saber/finetune/README.md | 0 saber/finetune/abiou.py | 2 +- saber/finetune/cli.py | 12 ++++ saber/finetune/dataset.py | 127 ++++++++++++++------------------------ saber/finetune/losses.py | 99 +++++++++++------------------ saber/finetune/prep.py | 74 +++++++++++++--------- saber/finetune/train.py | 10 +-- saber/finetune/trainer.py | 13 ++-- saber/main.py | 2 +- 9 files changed, 154 insertions(+), 185 deletions(-) create mode 100644 saber/finetune/README.md create mode 100644 saber/finetune/cli.py diff --git a/saber/finetune/README.md b/saber/finetune/README.md new file mode 100644 index 0000000..e69de29 diff --git a/saber/finetune/abiou.py b/saber/finetune/abiou.py index 3657927..9b73d50 100644 --- a/saber/finetune/abiou.py +++ b/saber/finetune/abiou.py @@ -204,7 +204,7 @@ def automask_metrics( gt_masks_per_image: List[List[Union[np.ndarray, torch.Tensor]]], # per-image list of HxW masks *, amg_kwargs: Optional[Dict[str, Any]] = None, - top_k: Optional[int] = 20, + top_k: Optional[int] = 200, device: Optional[torch.device] = None, autocast_ctx: Optional[Callable[[], Any]] = None, downsample_factor: int = 1, diff --git a/saber/finetune/cli.py b/saber/finetune/cli.py new file mode 100644 index 0000000..8c9aa8b --- /dev/null +++ b/saber/finetune/cli.py @@ -0,0 +1,12 @@ +from saber.finetune.train import finetune +from saber.finetune.prep import 
main as prep +import click + +@click.group(name="finetune") +def finetune_routines(): + """Routines for finetuning SAM2 on New Modalities.""" + pass + +# Add subcommands to the group +finetune_routines.add_command(finetune) +finetune_routines.add_command(prep) diff --git a/saber/finetune/dataset.py b/saber/finetune/dataset.py index b057d00..1da3412 100644 --- a/saber/finetune/dataset.py +++ b/saber/finetune/dataset.py @@ -13,8 +13,8 @@ def __init__(self, tomogram_zarr_path: str = None, fib_zarr_path: str = None, transform = None, - slabs_per_volume_per_epoch: int = 10, - slices_per_fib_per_epoch: int = 5, + num_slabs: int = 50, + num_slices: int = 50, slab_thickness: int = 5, seed: int = 42, shuffle: bool = True): @@ -23,19 +23,19 @@ def __init__(self, tomogram_zarr_path: Path to the tomogram zarr store fib_zarr_path: Path to the fib zarr store transform: Transform to apply to the data - slabs_per_volume_per_epoch: Number of slabs per volume per epoch - slices_per_fib_per_epoch: Number of slices per fib per epoch + num_slabs: Number of slabs per volume per epoch + num_slices: Number of slices per fib per epoch slab_thickness: Thickness of the slab """ # Slabs per Epoch self.slab_thickness = slab_thickness - self.slabs_per_volume_per_epoch = slabs_per_volume_per_epoch - self.slices_per_fib_per_epoch = slices_per_fib_per_epoch + self.num_slabs = num_slabs + self.num_slices = num_slices # Grid and Positive Points for AutoMaskGenerator self.points_per_side = 32 - self.min_area = 0.001 + self.min_pixels = 1e3 self.k_min = 50 self.k_max = 100 self.transform = transform @@ -76,7 +76,13 @@ def __init__(self, self.n_fib_volumes = len(self.fib_keys) self.fib_shapes = {} for i, key in enumerate(self.fib_keys): - self.fib_shapes[i] = self.fib_store[key]['0'].shape + try: + self.fib_shapes[i] = self._estimate_zrange(key, source='fib') + except Exception as e: + print(f"Error estimating zrange for fib {key}: {e}") + # remove key from fib_keys + self.fib_keys.remove(key) + 
self.n_fib_volumes = len(self.fib_keys) else: self.n_fib_volumes = 0 self.fib_shapes = {} @@ -93,12 +99,12 @@ def __init__(self, self.tomogram_samples = [] self.fib_samples = [] self._prev_tomogram_samples = [] - self._prev_fib_samples = [] + self._prev_fib_samples = [] # First sampling epoch self.resample_epoch() - def _estimate_zrange(self, key, band=(0.3, 0.7), threshold=0): + def _estimate_zrange(self, key, band=(0.3, 0.7), source = 'tomo', threshold=0): """ Returns (z_min, z_max) inclusive bounds for valid slab centers where there is some foreground in the labels. @@ -106,9 +112,15 @@ def _estimate_zrange(self, key, band=(0.3, 0.7), threshold=0): - band: fraction of Z to consider (lo, hi) """ - nz = self.tomogram_store[key]['0'].shape[0] + if source == 'tomo': + nz = self.tomogram_store[key]['0'].shape[0] + elif source == 'fib': + nz = self.fib_store[key]['0'].shape[0] min_offset, max_offset = int(nz * band[0]), int(nz * band[1]) - mask = self.tomogram_store[key]['labels/0'][min_offset:max_offset,] + if source == 'tomo': + mask = self.tomogram_store[key]['labels/0'][min_offset:max_offset,] + elif source == 'fib': + mask = self.fib_store[key]['labels/0'][min_offset:max_offset,] vals = mask.sum(axis=(1,2)) vals = np.nonzero(vals)[0] max_val = vals.max() - self.slab_thickness + min_offset @@ -119,7 +131,6 @@ def _compute_indices( self, D: int, N: int, - z_bounds: Optional[Tuple[Union[int, float], Union[int, float]]] = None, ) -> np.ndarray: """ Precompute N slice indices for a volume of depth D. 
@@ -140,18 +151,9 @@ def _compute_indices( center_bias = 0.9 sigma = 0.2 + z0, z1 = 0, D - 1 # --- Resolve bounds --- - if z_bounds is None: - z0, z1 = 0, D - 1 - else: - a, b = z_bounds - if isinstance(a, float) or isinstance(b, float): - # treat as fractions in [0,1] - a = int(np.floor(np.clip(a, 0.0, 1.0) * (D - 1))) - b = int(np.ceil (np.clip(b, 0.0, 1.0) * (D - 1))) - z0 = int(max(0, min(a, b))) - z1 = int(min(D - 1, max(a, b))) if z1 < z0: # empty window → fall back to full range z0, z1 = 0, D - 1 @@ -194,17 +196,14 @@ def resample_epoch(self): # Sample random slabs from each tomogram if self.has_tomogram: - print(f"Re-Sampling {self.slabs_per_volume_per_epoch} slabs from {self.n_tomogram_volumes} tomograms") + print(f"Re-Sampling {self.num_slabs} slabs from {self.n_tomogram_volumes} tomograms") new_tomo_samples = [] - for vol_idx in range(self.n_tomogram_volumes): + for vol_idx in tqdm(range(self.n_tomogram_volumes), desc="Sampling slabs from tomograms"): # Sample Indices in [0, D-1], center-biased within the *window* z_min, z_max = self.tomo_shapes[vol_idx] - z_rel = self._compute_indices( - z_max - z_min, - self.slabs_per_volume_per_epoch, - z_bounds=(z_min, z_max) - ) + D = (z_max - z_min + 1) + z_rel = self._compute_indices(D, self.num_slabs) # Shift to absolute z and add to samples z_positions = z_rel + z_min @@ -217,21 +216,20 @@ def resample_epoch(self): # Sample random slices from each FIB volume if self.has_fib: - print(f"Re-Sampling {self.slices_per_fib_per_epoch} slices from {self.n_fib_volumes} FIB volumes") + print(f"Re-Sampling {self.num_slices} slices from {self.n_fib_volumes} FIB volumes") new_fib_samples = [] - for fib_idx in range(self.n_fib_volumes): - fib_shape = self.fib_shapes[fib_idx] + for fib_idx in tqdm(range(self.n_fib_volumes), desc="Sampling slices from FIB volumes"): # Sample random z positions from this FIB volume - z_positions = self._compute_indices( - fib_shape[0], - self.slices_per_fib_per_epoch, - z_bounds=(0, 
fib_shape[0]) - ) - + z_min, z_max = self.fib_shapes[fib_idx] + D = (z_max - z_min + 1) + z_rel = self._compute_indices( + D, self.num_slices ) + # Shift to absolute z and add to samples + z_positions = z_rel + z_min for z_pos in z_positions: - new_fib_samples.append((fib_idx, z_pos)) + new_fib_samples.append((fib_idx, z_pos)) + self.fib_samples = new_fib_samples - self.fib_samples = self._update_samples(self.fib_samples, new_fib_samples) if self.shuffle: self._rng.shuffle(self.fib_samples) # Shuffle samples # Set epoch length @@ -273,11 +271,9 @@ def _get_fib_item(self, idx): key = self.fib_keys[fib_idx] # Load FIB image and segmentation - image = self.fib_store[key]['0'][z_pos,].astype(np.float32) - image_2d = preprocessing.proprocess(image) - seg_2d = self.fib_store[key]['labels/0'][z_pos,] - - return self._package_image_item(image_2d, seg_2d) + image = self.fib_store[key]['0'][z_pos,].astype(np.float32) # HxW + seg_2d = self.fib_store[key]['labels/0'][z_pos,] # HxW + return self._package_image_item(image, seg_2d) def _gen_grid_points(self, h: int, w: int) -> np.ndarray: """ @@ -402,12 +398,10 @@ def _package_image_item(self, # Apply transforms to image and segmentation if self.transform: sample = self.transform({'image': image_2d, 'mask': segmentation}) - image_2d, segmentation = sample['image'], sample['mask'] + image_2d, segmentation = sample['image'], sample['mask'] # Get image and segmentation shapes h, w = segmentation.shape - min_pixels = 0 - # min_pixels = int(self.min_area * h * w) # which instances to train on for this image grid_points = self._gen_grid_points(h, w) @@ -415,7 +409,7 @@ def _package_image_item(self, masks_t, points_t, labels_t, boxes_t = [], [], [], [] for iid in inst_ids: - comps = helper.components_for_id(segmentation, iid, min_pixels) + comps = helper.components_for_id(segmentation, iid, self.min_pixels) for comp in comps: # box from this component box = helper.mask_to_box(comp) @@ -450,35 +444,4 @@ def _package_image_item(self, 
"points": points_t, # list[#p x 2] float32 (xy) "labels": labels_t, # list[#p] all ones "boxes": boxes_t, # list[4] (x0,y0,x1,y1) - } - - # def _update_samples(self, old, new): - # """ - # Return a mixed list with size == len(new): - # - keep = min(round(len(new)*keep_fraction), len(old)) from 'old' - # - add = len(new) - keep from 'new' - # """ - # target = len(new) - # if target == 0: - # return [] - - # # choose keep set from *old* list, new set from *new* list - # keep = min(int(round(target * self.keep_fraction)), len(old)) - # add = target - keep - - # # keep set from *old* list - # if keep > 0: - # keep_idx = self._rng.choice(len(old), size=keep, replace=False) - # kept = [old[i] for i in keep_idx] - # else: - # kept = [] - - # # add set from *new* list - # if add > 0: - # new_idx = self._rng.choice(len(new), size=add, replace=False) - # added = [new[i] for i in new_idx] - # else: - # added = [] - - # # return mixed list - # return kept + added \ No newline at end of file + } \ No newline at end of file diff --git a/saber/finetune/losses.py b/saber/finetune/losses.py index 3e3fe4a..615866c 100644 --- a/saber/finetune/losses.py +++ b/saber/finetune/losses.py @@ -16,101 +16,72 @@ def focal_loss_from_logits(logits, targets, alpha=0.25, gamma=2.0): return (w * ((1 - pt) ** gamma) * ce).mean(dim=(1, 2)) class MultiMaskIoULoss(nn.Module): - """ - General loss for multi-mask predictions with IoU calibration. - Designed for SAM/SAM2 fine-tuning with AMG. 
- """ - - def __init__(self, - weight_dict: dict, - focal_alpha=0.25, - focal_gamma=2.0, - supervise_all_iou=False, - iou_use_l1_loss=True): + def __init__(self, weight_dict, focal_alpha=0.25, focal_gamma=2.0, + supervise_all_iou=True, iou_use_l1_loss=True, all_iou_weight=0.1): super().__init__() self.weight_dict = weight_dict - assert "loss_mask" in weight_dict - assert "loss_dice" in weight_dict - assert "loss_iou" in weight_dict self.focal_alpha = focal_alpha self.focal_gamma = focal_gamma self.supervise_all_iou = supervise_all_iou + self.all_iou_weight = all_iou_weight self.iou_use_l1_loss = iou_use_l1_loss def forward(self, prd_masks, prd_scores, gt_masks): - """ - Args - ---- - prd_masks: [N, K, H, W] logits from decoder - prd_scores: [N, K] predicted IoU logits (will be sigmoided here) - gt_masks: [N, H, W] float {0,1} - """ - + # prd_masks: [N,K,H,W] (logits), prd_scores: [N,K] (IoU logits), gt_masks: [N,H,W] (0/1) device = prd_masks.device N, K, H, W = prd_masks.shape - gt_masks = gt_masks.to(prd_masks.dtype) # [N,H,W] + gt_masks = gt_masks.to(prd_masks.dtype) - # ---- Compute hard predictions and true IoU per proposal (no grad) ---------- - with torch.no_grad(): - pred_bin = (prd_masks > 0.0).to(gt_masks.dtype) # [N,K,H,W] - gt_k = gt_masks[:, None].expand_as(pred_bin) # [N,K,H,W] - inter = (pred_bin * gt_k).sum(dim=(2, 3)) # [N,K] - union = (pred_bin + gt_k - pred_bin * gt_k).sum(dim=(2, 3)).clamp_min(1e-6) - true_iou_k = inter / union # [N,K] in [0,1] - - # ---- Per-proposal segmentation loss (focal + dice), select argmin ---------- - gt_rep = gt_masks.repeat_interleave(K, dim=0) # [N*K,H,W] + # --- compute per-proposal seg losses (no reduction) --- + gt_rep = gt_masks.repeat_interleave(K, dim=0) # [N*K,H,W] focal_per_k = focal_loss_from_logits( prd_masks.view(N*K, H, W), gt_rep, alpha=self.focal_alpha, gamma=self.focal_gamma - ).view(N, K) # [N,K] + ).view(N, K) # [N,K] dice_per_k = dice_loss_from_logits( prd_masks.view(N*K, H, W), gt_rep - ).view(N, 
K) # [N,K] - - seg_loss_per_k = focal_per_k + dice_per_k # [N,K] - # best_idx = seg_loss_per_k.argmin(dim=1) # [N] choose lowest seg loss - best_idx = true_iou_k.argmax(dim=1) # [N] choose highest IoU + ).view(N, K) # [N,K] + seg_loss_per_k = focal_per_k + dice_per_k + # --- choose the slot by lowest seg loss (SAM2-style) --- + best_idx = seg_loss_per_k.argmin(dim=1) # [N] row = torch.arange(N, device=device) - logits_star = prd_masks[row, best_idx] # [N,H,W] - true_iou_star = true_iou_k[row, best_idx].detach() # [N] + logits_star = prd_masks[row, best_idx] # [N,H,W] - # ---- Segmentation losses on the chosen proposal ---------------------------- - l_focal = focal_loss_from_logits( - logits_star, gt_masks, - alpha=self.focal_alpha, gamma=self.focal_gamma - ).mean() # scalar + # --- actual IoU per proposal (stop grad) --- + with torch.no_grad(): + pred_bin = (prd_masks > 0).to(gt_masks.dtype) # [N,K,H,W] + gt_k = gt_masks[:, None].expand_as(pred_bin) # [N,K,H,W] + inter = (pred_bin * gt_k).sum(dim=(2,3)) + union = (pred_bin + gt_k - pred_bin*gt_k).sum(dim=(2,3)).clamp_min(1e-6) + true_iou_k = inter / union # [N,K] + true_iou_star = true_iou_k[row, best_idx] + + # --- seg losses on the chosen slot only --- + l_focal = focal_loss_from_logits(logits_star, gt_masks, + alpha=self.focal_alpha, gamma=self.focal_gamma).mean() l_dice = dice_loss_from_logits(logits_star, gt_masks).mean() - # ---- IoU head regression (sigmoid + L1 by default) ------------------------- - pred_iou = prd_scores.sigmoid() # [N,K] in [0,1] - + # --- IoU head regression --- + pred_iou = prd_scores.sigmoid() if self.iou_use_l1_loss: l_iou = F.l1_loss(pred_iou[row, best_idx], true_iou_star) else: l_iou = F.mse_loss(pred_iou[row, best_idx], true_iou_star) - # Optional: supervise IoU for *all* proposals with small weight if self.supervise_all_iou: if self.iou_use_l1_loss: - l_iou_all = F.l1_loss(pred_iou, true_iou_k.detach()) + l_iou_all = F.l1_loss(pred_iou, true_iou_k) else: - l_iou_all = 
F.mse_loss(pred_iou, true_iou_k.detach()) - l_iou = l_iou + 0.1 * l_iou_all # tune 0.05–0.2 if needed + l_iou_all = F.mse_loss(pred_iou, true_iou_k) + l_iou = l_iou + self.all_iou_weight * l_iou_all - # ---- Weighted sum ---------------------------------------------------------- + # --- weighted sum --- loss_mask = l_focal loss_dice = l_dice loss_iou = l_iou + total = (self.weight_dict["loss_mask"] * loss_mask + + self.weight_dict["loss_dice"] * loss_dice + + self.weight_dict["loss_iou"] * loss_iou) - total_loss = (self.weight_dict["loss_mask"] * loss_mask + - self.weight_dict["loss_dice"] * loss_dice + - self.weight_dict["loss_iou"] * loss_iou) - - return { - "loss_mask": loss_mask, - "loss_dice": loss_dice, - "loss_iou": loss_iou, - "loss_total": total_loss, - } \ No newline at end of file + return {"loss_mask": loss_mask, "loss_dice": loss_dice, "loss_iou": loss_iou, "loss_total": total} \ No newline at end of file diff --git a/saber/finetune/prep.py b/saber/finetune/prep.py index ae74945..4d58ab3 100644 --- a/saber/finetune/prep.py +++ b/saber/finetune/prep.py @@ -19,7 +19,8 @@ def process_run_3d_simple( tomo_alg: str, organelle_names: List[str], min_component_volume: int = 100, - user_id: Optional[str] = None + user_id: Optional[str] = None, + values: Dict[str, int] = {}, ) -> Tuple[np.ndarray, np.ndarray, Dict[str, str]]: """ Process a single run in full 3D with simple volume-based sorting. 
@@ -51,10 +52,17 @@ def process_run_3d_simple( click.echo(f" Processing {organelle_name}...") seg = readers.segmentation(run, voxel_spacing, organelle_name, user_id=user_id) - if seg.shape != volume_3d.shape: - temp_seg_3d = np.zeros_like(seg, dtype=np.uint32) - - if seg is not None: + if seg is None: + click.echo(" No segmentation found") + continue + elif seg.shape != volume_3d.shape: + print(f"[Warning, {run.name}] Segmentation shape {seg.shape} does not match tomogram shape {volume_3d.shape}") + continue + elif values: + offset = values[organelle_name] + temp_seg_3d[seg > 0.5] = offset + components.append((organelle_name, offset, np.sum(seg > 0.5))) + else: # Convert to binary and separate connected components binary_mask = (seg > 0.5).astype(np.uint8) labeled_mask = label(binary_mask, connectivity=3) @@ -67,8 +75,6 @@ def process_run_3d_simple( temp_seg_3d[labeled_mask == label_val] = offset components.append((organelle_name, offset, vol)) offset += 1 - else: - click.echo(" No segmentation found") except Exception as e: click.echo(f" Error processing {organelle_name}: {e}") @@ -78,8 +84,9 @@ def process_run_3d_simple( # Create final segmentation with remapped labels and the mapping dictionary seg_3d = np.zeros_like(temp_seg_3d, dtype=np.uint16) id_to_organelle: Dict[str, str] = {} - for new_label, (organelle_name, old_label, _volume) in enumerate(components, start=1): + # Only remap if we have predefined values, else keep sequential values + if values: new_label = old_label seg_3d[temp_seg_3d == old_label] = new_label id_to_organelle[str(new_label)] = organelle_name @@ -92,10 +99,10 @@ def convert_copick_to_3d_zarr( output_json_path: Optional[str], voxel_spacing: float, tomo_alg: str, - specific_runs: Optional[List[str]], + specific_runs: str, min_component_volume: int, - compress: bool, user_id: Optional[str], + keep_labels: bool ): """ Convert copick data to 3D zarr format with JSON segmentation mapping. 
@@ -111,12 +118,18 @@ def convert_copick_to_3d_zarr( organelle_names = [x for x in organelle_names if "membrane" not in x] click.echo(f"Found organelles: {organelle_names}") + # Prepare organelle values if keeping labels + if keep_labels: + organelle_values = {obj.name: obj.label for obj in root.pickable_objects if not obj.is_particle} + else: + organelle_values = {} + # Set default JSON output path if output_json_path is None: output_json_path = output_zarr_path.replace(".zarr", "_mapping.json") # Initialize zarr store - compressor = zarr.Blosc(cname="zstd", clevel=2) if compress else None + compressor = zarr.Blosc(cname="zstd", clevel=2) store = zarr.DirectoryStore(output_zarr_path) zroot = zarr.group(store=store, overwrite=True) @@ -124,7 +137,7 @@ def convert_copick_to_3d_zarr( master_mapping: Dict[str, Dict[str, str]] = {} # Determine which runs to process - runs_to_process = specific_runs if specific_runs else [run.name for run in root.runs] + runs_to_process = specific_runs.split(",") if specific_runs else [run.name for run in root.runs] for run_name in tqdm(runs_to_process, desc="Processing runs"): click.echo(f"\nProcessing run: {run_name}") @@ -139,6 +152,7 @@ def convert_copick_to_3d_zarr( organelle_names=organelle_names, min_component_volume=min_component_volume, user_id=user_id, + values = organelle_values ) # Create zarr group for this run @@ -176,8 +190,11 @@ def convert_copick_to_3d_zarr( continue # Save master JSON mapping - with open(output_json_path, "w") as f: - json.dump(master_mapping, f, indent=2) + if keep_labels: + zroot.attrs['label_values'] = organelle_values + else: + with open(output_json_path, "w") as f: + json.dump(master_mapping, f, indent=2) click.echo("\n🎉 Conversion complete!") click.echo(f"📁 Zarr output: {output_zarr_path}") @@ -212,7 +229,7 @@ def load_3d_zarr_data(zarr_path: str, run_name: str) -> Tuple[np.ndarray, np.nda return volume, labels, id_to_organelle -@click.command(context_settings={"show_default": True}) 
+@click.command(context_settings={"show_default": True}, name='prep') @click.option( "--config", "config_path", @@ -248,15 +265,16 @@ def load_3d_zarr_data(zarr_path: str, run_name: str) -> Tuple[np.ndarray, np.nda help="Tomogram algorithm to use for processing.", ) @click.option( - "--specific-run", + "--runIDs", "specific_runs", - multiple=True, - help="Process only specific runs. Repeat this option for multiple runs.", + type=str, + default=None, + help="Process only specific runs. Provide a comma-separated list of runIDs.", ) @click.option( "--min-component-volume", type=int, - default=10000, + default=1e3, help="Minimum connected-component volume (in voxels).", ) @click.option( @@ -266,10 +284,10 @@ def load_3d_zarr_data(zarr_path: str, run_name: str) -> Tuple[np.ndarray, np.nda help="UserID for accessing segmentation.", ) @click.option( - "--no-compress", - is_flag=True, + "--keep-labels", + type=bool, default=False, - help="Disable compression for zarr storage.", + help="Save Segmentations with Values Defined from the Copick Configuration File", ) def main( config_path: str, @@ -277,11 +295,12 @@ def main( output_json_path: Optional[str], voxel_spacing: float, tomo_alg: str, - specific_runs: Optional[List[str]], + specific_runs: str, min_component_volume: int, user_id: Optional[str], - no_compress: bool, -): + keep_labels: bool, + ): + """Convert copick data to 3D zarr format with JSON segmentation mapping.""" convert_copick_to_3d_zarr( config_path=config_path, @@ -289,10 +308,9 @@ def main( output_json_path=output_json_path, voxel_spacing=voxel_spacing, tomo_alg=tomo_alg, - specific_runs=list(specific_runs) if specific_runs else None, + specific_runs=specific_runs, min_component_volume=min_component_volume, - compress=not no_compress, - user_id=user_id, + keep_labels=keep_labels, user_id=user_id, ) diff --git a/saber/finetune/train.py b/saber/finetune/train.py index 2c9746b..a73699f 100644 --- a/saber/finetune/train.py +++ b/saber/finetune/train.py @@ -8,7 
+8,7 @@ from sam2.build_sam import build_sam2 from saber import pretrained_weights from saber.utils import io -import click +import click, os def finetune_sam2( tomo_train: str = None, @@ -33,14 +33,14 @@ def finetune_sam2( # Load data loaders train_loader = DataLoader( AutoMaskDataset( - tomo_train, fib_train, transform=get_finetune_transforms(), - slabs_per_volume_per_epoch=20 ), + tomo_train, fib_train, + transform=get_finetune_transforms()), batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate_autoseg ) val_loader = DataLoader( AutoMaskDataset( - tomo_val, fib_val, slabs_per_volume_per_epoch=15 ), + tomo_val, fib_val, shuffle=False ), num_workers=4, pin_memory=True, collate_fn=collate_autoseg, batch_size=batch_size, shuffle=False ) if (tomo_val or fib_val) else train_loader @@ -49,7 +49,7 @@ def finetune_sam2( # trainer.train( num_epochs, best_metric='AR' ) trainer.train( num_epochs ) -@click.command() +@click.command(context_settings={"show_default": True}, name='run') @sam2_inputs @click.option("--fib-train", type=str, help="Path to train Zarr") @click.option("--fib-val", type=str, help="Path to val Zarr") diff --git a/saber/finetune/trainer.py b/saber/finetune/trainer.py index 2b060e8..bc18aae 100644 --- a/saber/finetune/trainer.py +++ b/saber/finetune/trainer.py @@ -9,6 +9,10 @@ from lightning import fabric from tqdm import tqdm +# Tests: +# Weights - +# Soften Mask Prompts Logits - ±3 + class SAM2FinetuneTrainer: def __init__(self, predictor, train_loader, val_loader, seed=42): @@ -66,7 +70,7 @@ def _autocast(): crop_n_layers=0, crop_n_points_downscale_factor=2, box_nms_thresh=0.6, - use_m2m=True, + use_m2m=False, multimask_output=False, ) self.nAMGtrials = 10 @@ -102,7 +106,8 @@ def _stack_image_embeddings_from_predictor(self): hr_feats = [torch.stack([lvl[b] for b in range(B)], dim=0).to(self.device) for lvl in hr] return image_embeds, hr_feats - def _determine_sampling(self, N, p_points=0.5, p_box=0.15, p_mask=0.2, 
p_mask_box=0.15): +    # p_points=0.5, p_box=0.15, p_mask=0.2, p_mask_box=0.15 +    def _determine_sampling(self, N, p_points=0.4, p_box=0.15, p_mask=0.25, p_mask_box=0.25): """ Decide which prompt combo each instance uses. Returns a list[int] of length N with codes: @@ -222,7 +227,7 @@ def forward_step(self, batch): # 5) mask logits (+/-6) gt_masks_bin = torch.stack([m.to(torch.float32) for m in gt_all], dim=0).to(self.device) - mask_logits_full = (gt_masks_bin * 2.0 - 1.0) * 6.0 + mask_logits_full = (gt_masks_bin * 2.0 - 1.0) * 3.0 # Before it was ±6 # 6) build per-instance prompts pts_pad, lbl_pad, boxes, mask_logits = self._process_inputs( @@ -331,7 +336,7 @@ def validate_step(self, max_images=float('inf'), best_metric='ABIoU'): self.predictor, # predictor or predictor.model (your function supports either) batch["images"], # list[H×W×3] or list[H×W] batch["masks"], # list[list[H×W]] - top_k=25, + top_k=150, device=self.device, autocast_ctx=self.autocast, amg_kwargs=self.amg_kwargs, diff --git a/saber/main.py b/saber/main.py index 68c7601..768d888 100644 --- a/saber/main.py +++ b/saber/main.py @@ -1,9 +1,9 @@ from saber.classifier.cli import classifier_routines as classifier from saber.entry_points.run_low_pass_filter import cli as filter3d from saber.entry_points.segment_methods import methods as segment +from saber.finetune.cli import finetune_routines as finetune from saber.analysis.analysis_cli import methods as analysis from saber.entry_points.run_analysis import cli as save -from saber.finetune.train import finetune import click try: from saber.gui.base.zarr_gui import gui From 791ab29402dcb97c91f41de6ae3367bd75a38b10 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 04:21:41 +0000 Subject: [PATCH 15/15] near working implementation for sam2 --- saber/finetune/dataset.py | 40 ++++- saber/finetune/helper.py | 72 +++++++- saber/finetune/losses.py | 208 +++++++++++++++--------- saber/finetune/trainer.py | 335 +++++++++++++++++++++----------------- 4 files
changed, 430 insertions(+), 225 deletions(-) diff --git a/saber/finetune/dataset.py b/saber/finetune/dataset.py index 1da3412..b6207b7 100644 --- a/saber/finetune/dataset.py +++ b/saber/finetune/dataset.py @@ -41,6 +41,8 @@ def __init__(self, self.transform = transform self.keep_fraction = 0.5 self.shuffle = shuffle + self.max_pos_pts = 8 + self.K_choices = [2, 2, 4, 4, 8] # Check if both data types are available if tomogram_zarr_path is None and fib_zarr_path is None: @@ -304,7 +306,7 @@ def _sample_points_in_mask( grid_points: np.ndarray, shape: tuple[int, int], jitter_px: float = 1.0, - k_cap: int = 300, + k_cap: int = 100, boundary_frac: float = 0.35, ) -> np.ndarray: """ @@ -340,7 +342,7 @@ def _sample_points_in_mask( # target k ~ c * area but capped area = float(comp_b.sum()) k_target = int( - min( k_cap, max(24, area * 0.12) ) + min( k_cap, max(12, area * 0.1) ) ) kb = int(boundary_frac * k_target) @@ -416,15 +418,37 @@ def _package_image_item(self, if box is None: continue - # sample clicks from this component (NOT the full instance) - # pts = helper.sample_positive_points(comp, k=self.k_pos) - pts = self._sample_points_in_mask(comp, grid_points, shape=(h, w)) - if pts.shape[0] == 0: - continue + K = self.K_choices[self._rng.randint(0, len(self.K_choices))] + pts_pos = self._sample_points_in_mask(comp, grid_points, (h, w)) + if pts_pos.shape[0] < 3: + continue + elif pts_pos.shape[0] > K: + sel = self._rng.choice(pts_pos.shape[0], size=K, replace=False) + pts_pos = pts_pos[sel] + + # negatives just outside this component, but not inside any *other* instance + use_negs = (self._rng.rand() < 0.5) + if use_negs: + other_inst = ((segmentation > 0) & (~comp.astype(bool))) + neg_ring = self._sample_negative_ring( + comp, other_inst=other_inst, ring=3, + max_neg=min(4, int(0.25 * len(pts_pos))), shape=(h, w) + ) + + pts = np.concatenate([pts_pos, neg_ring], 0) + lbl = np.concatenate([np.ones(len(pts_pos), np.float32), + np.zeros(len(neg_ring), np.float32)], 0) + 
else: + pts, lbl = pts_pos, np.ones((len(pts_pos),), np.float32) + + # shuffle to avoid positional bias + if pts.shape[0] > 1: + order = self._rng.permutation(pts.shape[0]) + pts, lbl = pts[order], lbl[order] masks_t.append(torch.from_numpy(comp.astype(np.float32))) points_t.append(torch.from_numpy(pts.astype(np.float32))) - labels_t.append(torch.from_numpy(np.ones((pts.shape[0],), dtype=np.float32))) + labels_t.append(torch.from_numpy(lbl.astype(np.float32))) # 1=pos, 0=neg boxes_t.append(torch.from_numpy(box.astype(np.float32))) # fallback to a harmless dummy if nothing was hit by the grid (keeps loader stable) diff --git a/saber/finetune/helper.py b/saber/finetune/helper.py index f8a9597..e58e5d0 100644 --- a/saber/finetune/helper.py +++ b/saber/finetune/helper.py @@ -311,4 +311,74 @@ def _orig_hw_tuple(pred): if inst_masks: gt_masks = torch.stack([torch.as_tensor(m, device=device).float() for m in inst_masks], dim=0) - return prd_masks, prd_scores, gt_masks \ No newline at end of file + return prd_masks, prd_scores, gt_masks + +@torch.no_grad() +def _binary_iou(a: torch.Tensor, b: torch.Tensor, eps=1e-6) -> torch.Tensor: + # a,b: (N,H,W) boolean/0-1 + inter = (a & b).float().sum(dim=(1,2)) + uni = (a | b).float().sum(dim=(1,2)).clamp_min(eps) + return inter / uni + +@torch.no_grad() +def stability_from_logits(mask_logits: torch.Tensor, delta: float = 0.05) -> torch.Tensor: + """ + mask_logits: (K,H,W) raw logits from the decoder. + We compute IoU between masks thresholded at 0 and +/- delta in logit-space. + Returns: (K,) stability score in [0,1]. + """ + # Note: comparing binary maps at thresholds t1 and t2 is enough; no need to sigmoid. 
+ m_lo = (mask_logits > -delta) # t = -delta + m_hi = (mask_logits > +delta) # t = +delta + # Using IoU between the two gives the 'stability' used by AMG + return _binary_iou(m_lo, m_hi) + +@torch.no_grad() +def dynamic_multimask_predict( + predictor, + image_batch_entry, + points=None, labels=None, box=None, mask_input=None, + stability_thresh: float = 0.98, + stability_delta: float = 0.05, + multimask_k: int = 3, +): + """ + 1) Run single-mask decode (multimask_output=False). + 2) Compute stability; if stable >= thresh, keep single. + 3) Otherwise rerun with multimask_output=True to get K candidates. + + Returns: + prd_masks: (K,H,W) float logits (not sigmoid) **if multimask**, else (1,H,W) + prd_scores: (K,) IoU-head predictions (or (1,)) + aux: dict with {'stability': tensor} + """ + # First pass: single mask + out1 = predictor.predict( + image=image_batch_entry, + point_coords=points, point_labels=labels, + box=box, mask_input=mask_input, + multimask_output=False, # <-- single + return_logits=True, # <-- to compute stability + return_iou_predictions=True, # <-- IoU head (SAM2) + ) + logits1 = out1["masks_logits"][0] # (1,H,W) + iou1 = out1["iou_predictions"][0] # (1,) + stab1 = stability_from_logits(logits1, delta=stability_delta) # (1,) + + if stab1[0].item() >= stability_thresh: + return logits1, iou1, {"stability": stab1} + + # Not stable -> get multimask + outk = predictor.predict( + image=image_batch_entry, + point_coords=points, point_labels=labels, + box=box, mask_input=mask_input, + multimask_output=True, # <-- gated on stability + num_multimask_outputs=multimask_k, # if your API supports it; otherwise ignored + return_logits=True, + return_iou_predictions=True, + ) + logitsk = outk["masks_logits"][0] # (K,H,W) + iouk = outk["iou_predictions"][0] # (K,) + stabk = stability_from_logits(logitsk, delta=stability_delta) # (K,) + return logitsk, iouk, {"stability": stabk} \ No newline at end of file diff --git a/saber/finetune/losses.py 
b/saber/finetune/losses.py index 615866c..d9e1950 100644 --- a/saber/finetune/losses.py +++ b/saber/finetune/losses.py @@ -9,79 +9,155 @@ def dice_loss_from_logits(logits, targets, eps=1e-6): return 1 - (2 * inter + eps) / (denom + eps) def focal_loss_from_logits(logits, targets, alpha=0.25, gamma=2.0): + # logits, targets: (N,H,W) float with targets in {0,1} ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none") p = torch.sigmoid(logits) pt = p * targets + (1 - p) * (1 - targets) w = alpha * targets + (1 - alpha) * (1 - targets) - return (w * ((1 - pt) ** gamma) * ce).mean(dim=(1, 2)) + return (w * ((1 - pt) ** gamma) * ce).mean(dim=(1, 2)) # -> (N,) class MultiMaskIoULoss(nn.Module): - def __init__(self, weight_dict, focal_alpha=0.25, focal_gamma=2.0, - supervise_all_iou=True, iou_use_l1_loss=True, all_iou_weight=0.1): + def __init__( + self, + iou_regression: str = "l1", # "l1" or "mse" + supervise_all_iou: bool = False, + all_iou_weight: float = 0.1, + # Weights: match the (dice=20, focal=1, iou=1, class=1) convention + weight_dict: dict = None, + focal_alpha: float = 0.25, + focal_gamma: float = 2.0, + # Objectness head + pred_obj_scores: bool = True, + focal_alpha_obj: float = 0.5, # can be -1 to disable alpha weighting + focal_gamma_obj: float = 0.0, # 0 -> plain BCE on objectness + gate_by_objectness: bool = True, # gate seg/IoU losses when no object + ): super().__init__() - self.weight_dict = weight_dict - self.focal_alpha = focal_alpha - self.focal_gamma = focal_gamma + if weight_dict is None: + weight_dict = {"loss_mask": 1.0, "loss_dice": 20.0, "loss_iou": 1.0, "loss_class": 1.0} + # NOTE: "loss_mask" == focal; "loss_dice" == dice; "loss_class" == objectness + self.focal_weight = weight_dict.get("loss_mask", 1.0) + self.dice_weight = weight_dict.get("loss_dice", 20.0) + self.iou_head_weight = weight_dict.get("loss_iou", 1.0) + self.class_weight = weight_dict.get("loss_class", 1.0) + + self.iou_regression = iou_regression 
self.supervise_all_iou = supervise_all_iou self.all_iou_weight = all_iou_weight - self.iou_use_l1_loss = iou_use_l1_loss - - def forward(self, prd_masks, prd_scores, gt_masks): - # prd_masks: [N,K,H,W] (logits), prd_scores: [N,K] (IoU logits), gt_masks: [N,H,W] (0/1) - device = prd_masks.device - N, K, H, W = prd_masks.shape - gt_masks = gt_masks.to(prd_masks.dtype) - - # --- compute per-proposal seg losses (no reduction) --- - gt_rep = gt_masks.repeat_interleave(K, dim=0) # [N*K,H,W] - focal_per_k = focal_loss_from_logits( - prd_masks.view(N*K, H, W), gt_rep, - alpha=self.focal_alpha, gamma=self.focal_gamma - ).view(N, K) # [N,K] - dice_per_k = dice_loss_from_logits( - prd_masks.view(N*K, H, W), gt_rep - ).view(N, K) # [N,K] - seg_loss_per_k = focal_per_k + dice_per_k - - # --- choose the slot by lowest seg loss (SAM2-style) --- - best_idx = seg_loss_per_k.argmin(dim=1) # [N] - row = torch.arange(N, device=device) - logits_star = prd_masks[row, best_idx] # [N,H,W] - - # --- actual IoU per proposal (stop grad) --- - with torch.no_grad(): - pred_bin = (prd_masks > 0).to(gt_masks.dtype) # [N,K,H,W] - gt_k = gt_masks[:, None].expand_as(pred_bin) # [N,K,H,W] - inter = (pred_bin * gt_k).sum(dim=(2,3)) - union = (pred_bin + gt_k - pred_bin*gt_k).sum(dim=(2,3)).clamp_min(1e-6) - true_iou_k = inter / union # [N,K] - true_iou_star = true_iou_k[row, best_idx] - - # --- seg losses on the chosen slot only --- - l_focal = focal_loss_from_logits(logits_star, gt_masks, - alpha=self.focal_alpha, gamma=self.focal_gamma).mean() - l_dice = dice_loss_from_logits(logits_star, gt_masks).mean() - - # --- IoU head regression --- - pred_iou = prd_scores.sigmoid() - if self.iou_use_l1_loss: - l_iou = F.l1_loss(pred_iou[row, best_idx], true_iou_star) + self.focal_alpha = focal_alpha + self.focal_gamma = focal_gamma + + self.pred_obj_scores = pred_obj_scores + self.focal_alpha_obj = focal_alpha_obj + self.focal_gamma_obj = focal_gamma_obj + self.gate_by_objectness = gate_by_objectness + + def 
_reg_loss(self, pred, target): + return F.mse_loss(pred, target) if self.iou_regression == "mse" else F.l1_loss(pred, target) + + @staticmethod + def _binary_iou(bin_mask, gt_bool, eps=1e-6): + inter = (bin_mask & gt_bool).float().sum(dim=(1, 2)) + union = (bin_mask | gt_bool).float().sum(dim=(1, 2)).clamp_min(eps) + return inter / union + + def _objectness_targets(self, gt_mask): + # gt_mask: (H,W) bool/0-1 → scalar 0/1 (any positive pixel?) + return (gt_mask.sum() > 0).float() + + def _objectness_loss(self, logits_scalar, target_scalar): + # logits_scalar: shape () or (1,) + # Use focal-like BCE if gamma>0, else plain BCE + if self.focal_gamma_obj == 0.0 and self.focal_alpha_obj < 0: + return F.binary_cross_entropy_with_logits(logits_scalar, target_scalar) + p = torch.sigmoid(logits_scalar) + ce = F.binary_cross_entropy_with_logits(logits_scalar, target_scalar, reduction="none") + pt = p * target_scalar + (1 - p) * (1 - target_scalar) + if self.focal_alpha_obj >= 0: + w = self.focal_alpha_obj * target_scalar + (1 - self.focal_alpha_obj) * (1 - target_scalar) else: - l_iou = F.mse_loss(pred_iou[row, best_idx], true_iou_star) - - if self.supervise_all_iou: - if self.iou_use_l1_loss: - l_iou_all = F.l1_loss(pred_iou, true_iou_k) - else: - l_iou_all = F.mse_loss(pred_iou, true_iou_k) - l_iou = l_iou + self.all_iou_weight * l_iou_all - - # --- weighted sum --- - loss_mask = l_focal - loss_dice = l_dice - loss_iou = l_iou - total = (self.weight_dict["loss_mask"] * loss_mask - + self.weight_dict["loss_dice"] * loss_dice - + self.weight_dict["loss_iou"] * loss_iou) - - return {"loss_mask": loss_mask, "loss_dice": loss_dice, "loss_iou": loss_iou, "loss_total": total} \ No newline at end of file + w = 1.0 + return (w * ((1 - pt) ** self.focal_gamma_obj) * ce).mean() + + def forward(self, prd_masks_logits, prd_iou_scores, gt_masks, object_score_logits=None): + """ + prd_masks_logits: (N,K,H,W) or (K,H,W) + prd_iou_scores: (N,K) or (K,) + gt_masks: (N,H,W) or (H,W) + 
object_score_logits: (N,) or (N,1) or scalar per instance (optional) + """ + # ---- normalize shapes to batched form ---- + if prd_masks_logits.dim() == 3: # (K,H,W) -> (1,K,H,W) + prd_masks_logits = prd_masks_logits.unsqueeze(0) + if prd_iou_scores.dim() == 1: # (K,) -> (1,K) + prd_iou_scores = prd_iou_scores.unsqueeze(0) + if gt_masks.dim() == 2: # (H,W) -> (1,H,W) + gt_masks = gt_masks.unsqueeze(0) + + N, K, H, W = prd_masks_logits.shape + gt_bool = gt_masks.bool() # (N,H,W) + gt_float = gt_masks.float() # (N,H,W) + + # ---- IoU per slot (vectorized) ---- + bin_masks = (prd_masks_logits > 0) # (N,K,H,W) + gt_bool_b = gt_bool.unsqueeze(1) # (N,1,H,W) <-- key fix + inter = (bin_masks & gt_bool_b).float().sum(dim=(2,3)) # (N,K) + union = (bin_masks | gt_bool_b).float().sum(dim=(2,3)).clamp_min(1e-6) # (N,K) + true_iou_per_k = inter / union # (N,K) + + # ---- pick best slot per instance ---- + best_ix = torch.argmax(true_iou_per_k, dim=1) # (N,) + idx_4d = best_ix.view(N,1,1,1).expand(N,1,H,W) + best_logits = prd_masks_logits.gather(1, idx_4d).squeeze(1) # (N,H,W) + + # ---- segmentation losses on best slot ---- + dice_per = dice_loss_from_logits(best_logits, gt_float) # (N,) + focal_per = focal_loss_from_logits( + best_logits, gt_float, + alpha=self.focal_alpha, gamma=self.focal_gamma + ) # (N,) + seg_loss = self.dice_weight * dice_per.mean() + self.focal_weight * focal_per.mean() + + # ---- IoU head regression ---- + iou_pred_best = prd_iou_scores.gather(1, best_ix.view(N,1)).squeeze(1) # (N,) + iou_target_best = true_iou_per_k.gather(1, best_ix.view(N,1)).squeeze(1) # (N,) + + # IMPORTANT: prd_iou_scores are already probs in [0,1]; do NOT apply sigmoid here. + # Regress directly (L1 or MSE) to the true IoU. 
+ iou_reg_main = self._reg_loss(iou_pred_best, iou_target_best) + + total = seg_loss + self.iou_head_weight * iou_reg_main + + # ---- optional gentle supervision over all K ---- + iou_reg_all = None + if self.supervise_all_iou and K > 1 and self.all_iou_weight > 0: + # SAME: no sigmoid; train all K scores directly toward their per-slot IoUs + iou_reg_all = self._reg_loss(prd_iou_scores, true_iou_per_k) + total = total + self.all_iou_weight * iou_reg_all + + # ---- objectness (per instance) ---- + obj_loss = None + if self.pred_obj_scores: + assert object_score_logits is not None, "object_score_logits required when pred_obj_scores=True" + obj_logits = object_score_logits.view(N, -1).mean(dim=1) # (N,) + obj_target = (gt_bool.view(N, -1).sum(dim=1) > 0).float() # (N,) + obj_loss = self._objectness_loss(obj_logits, obj_target) + total = total + self.class_weight * obj_loss + + if self.gate_by_objectness: + has_obj = (obj_target > 0.5).float() # (N,) + gate = has_obj.mean().clamp_min(1e-6) + total = self.class_weight * obj_loss + gate * (seg_loss + self.iou_head_weight * iou_reg_main) + if iou_reg_all is not None: + total = total + gate * (self.all_iou_weight * iou_reg_all) + + return { + "loss_total": total, + "loss_seg": seg_loss.detach(), + "loss_iou": iou_reg_main.detach(), + "loss_iou_all": (iou_reg_all.detach() if iou_reg_all is not None else None), + "loss_class": (obj_loss.detach() if obj_loss is not None else None), + "loss_mask": focal_per.mean().detach(), + "loss_dice": dice_per.mean().detach(), + "true_iou_best": iou_target_best.mean().detach(), + } \ No newline at end of file diff --git a/saber/finetune/trainer.py b/saber/finetune/trainer.py index bc18aae..a7c3412 100644 --- a/saber/finetune/trainer.py +++ b/saber/finetune/trainer.py @@ -106,86 +106,26 @@ def _stack_image_embeddings_from_predictor(self): hr_feats = [torch.stack([lvl[b] for b in range(B)], dim=0).to(self.device) for lvl in hr] return image_embeds, hr_feats - # p_points=0.5, p_box=0.15, 
p_mask=0.2, p_mask_box=0.15 - def _determine_sampling(self, N, p_points=0.4, p_box=0.15, p_mask=0.25, p_mask_box=0.25): + def _determine_sampling(self, N, p_points=0.5, p_box=0.2, p_mask=0.2, p_mask_box=0.1): """ - Decide which prompt combo each instance uses. - Returns a list[int] of length N with codes: 0 = points only 1 = box + points 2 = mask + points 3 = mask + box + points """ - # normalize to avoid drift if probs don't sum to 1 exactly probs = [p_points, p_box, p_mask, p_mask_box] s = sum(probs); probs = [p / s for p in probs] - # cumulative edges for a single uniform draw - e0 = probs[0] - e1 = e0 + probs[1] - e2 = e1 + probs[2] + # cumulative edges + e0, e1, e2 = probs[0], probs[0] + probs[1], probs[0] + probs[1] + probs[2] combo = [] - for _ in range(N): + groups = {0: [], 1: [], 2: [], 3: []} + for i in range(N): r = self._rng.random() - if r < e0: combo.append(0) - elif r < e1: combo.append(1) - elif r < e2: combo.append(2) - else: combo.append(3) - return combo - - def _process_inputs(self, N, mask_logits_full, pts_all, lbl_all, boxes_full, combo): - """ - Build per-instance prompts to feed _prep_prompts(): - - trim points when also using box/mask (keep 1–3 anchors) - - pad clicks to (N, P, 2) and (N, P) with labels=-1 for ignored slots - - select boxes/mask_logits per instance based on combo - """ - device = self.device - - # Which instances use which prompts - use_boxes = torch.tensor([c in (1, 3) for c in combo], device=device) - use_masks = torch.tensor([c in (2, 3) for c in combo], device=device) - - # ---- Trim clicks (when box/mask present we keep a few anchors to avoid over-conditioning) - pts_trim, lbl_trim = [], [] - for i, (p, l) in enumerate(zip(pts_all, lbl_all)): - if combo[i] in (1, 2, 3) and p.shape[0] > 3: - pts_trim.append(p[:3]) - lbl_trim.append(l[:3]) - else: - pts_trim.append(p) - lbl_trim.append(l) - - # ---- Pad to dense tensors; labels=-1 means "ignore" for _prep_prompts - max_p = max((p.shape[0] for p in pts_trim), default=0) 
- pts_pad = torch.zeros((N, max_p, 2), device=device, dtype=torch.float32) - lbl_pad = torch.full((N, max_p), -1.0, device=device, dtype=torch.float32) - for i, (p, l) in enumerate(zip(pts_trim, lbl_trim)): - if p.numel(): - pts_pad[i, :p.shape[0]] = p.to(device, dtype=torch.float32) - lbl_pad[i, :l.shape[0]] = l.to(device, dtype=torch.float32) - - # ---- Ensure boxes_full exists & is float32; supply dummy box when unused - if boxes_full is None: - boxes_full = torch.tensor([[0, 0, 1, 1]], device=device, dtype=torch.float32).expand(N, 4) - else: - boxes_full = boxes_full.to(device, dtype=torch.float32) - - boxes_sel = torch.where( - use_boxes[:, None], - boxes_full, - torch.tensor([0, 0, 1, 1], device=device, dtype=torch.float32).expand_as(boxes_full) - ) - - # ---- Gate mask logits per instance (mask prompt when requested; zeros otherwise) - mask_logits_sel = torch.where( - use_masks[:, None, None], - mask_logits_full.to(device, dtype=torch.float32), - torch.zeros_like(mask_logits_full, device=device, dtype=torch.float32) - ) - - return pts_pad, lbl_pad, boxes_sel, mask_logits_sel - + c = 0 if r < e0 else 1 if r < e1 else 2 if r < e2 else 3 + combo.append(c) + groups[c].append(i) + return combo, groups def forward_step(self, batch): """ @@ -217,70 +157,155 @@ def forward_step(self, batch): return None, None, None, None inst_img_ix = torch.tensor(inst_img_ix, device=self.device, dtype=torch.long) - # 3) prompt combos - combo = self._determine_sampling(N) - - # 4) boxes - boxes_full = torch.stack(box_all, dim=0).to(self.device, dtype=torch.float32) if len(box_all) > 0 else None - if boxes_full is None: - boxes_full = torch.tensor([[0, 0, 1, 1]], device=self.device, dtype=torch.float32).expand(N, 4) + # (A) build raw per-instance tensors + boxes_full = torch.stack(box_all, dim=0).to(self.device, dtype=torch.float32) if len(box_all) > 0 else \ + torch.tensor([[0,0,1,1]], device=self.device, dtype=torch.float32).expand(N,4) # 5) mask logits (+/-6) gt_masks_bin = 
torch.stack([m.to(torch.float32) for m in gt_all], dim=0).to(self.device) - mask_logits_full = (gt_masks_bin * 2.0 - 1.0) * 3.0 # Before it was ±6 + mask_logits_full = (gt_masks_bin * 2.0 - 1.0) * 3.0 - # 6) build per-instance prompts - pts_pad, lbl_pad, boxes, mask_logits = self._process_inputs( - N, mask_logits_full, pts_all, lbl_all, boxes_full, combo - ) - has_any_mask = (mask_logits is not None) and (mask_logits.abs().sum() > 0) - - # 7) prep prompts (prompt-space outputs) - mask_input, point_coords, point_labels, boxes_input = self.predictor._prep_prompts( - pts_pad, lbl_pad, - box=boxes, - mask_logits=(mask_logits if has_any_mask else None), - normalize_coords=True - ) + # pad clicks + max_p = max((p.shape[0] for p in pts_all), default=0) + pts_pad = torch.zeros((N, max_p, 2), device=self.device, dtype=torch.float32) + lbl_pad = torch.full((N, max_p), -1.0, device=self.device, dtype=torch.float32) + for i, (p, l) in enumerate(zip(pts_all, lbl_all)): + if p.numel(): + pts_pad[i, :p.shape[0]] = p.to(self.device, torch.float32) + lbl_pad[i, :l.shape[0]] = l.to(self.device, torch.float32) - # --- shape fix + spatial size for dense mask prompt --- - Hf, Wf = image_embeds_B.shape[-2], image_embeds_B.shape[-1] - target_mask_h, target_mask_w = Hf * 4, Wf * 4 - - if mask_input is not None: - mask_input = mask_input.to(self.device, dtype=torch.float32) - if mask_input.dim() == 3: - mask_input = mask_input.unsqueeze(1) # [N,1,H,W] - elif mask_input.dim() == 4 and mask_input.shape[0] == 1 and mask_input.shape[1] > 1: - mask_input = mask_input.permute(1, 0, 2, 3).contiguous() # [N,1,H,W] - if mask_input.shape[1] != 1: - mask_input = mask_input[:, :1] - if mask_input.shape[-2:] != (target_mask_h, target_mask_w): - mask_input = F.interpolate(mask_input, (target_mask_h, target_mask_w), mode="bilinear", align_corners=False) - - # 8) encode prompts (use prompt-space tensors) - sparse_embeddings, dense_embeddings = self.predictor.model.sam_prompt_encoder( - 
points=(point_coords, point_labels), - boxes=boxes_input, - masks=mask_input, - ) + # (B) sampling + groups + combo, groups = self._determine_sampling(N) - # 9) gather image feats per instance - image_embeds = image_embeds_B[inst_img_ix] - hr_feats = [lvl[inst_img_ix] for lvl in hr_feats_B] - - # 10) decode - low_res_masks, prd_scores, _, _ = self.predictor.model.sam_mask_decoder( - image_embeddings=image_embeds, - image_pe=self.predictor.model.sam_prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, - multimask_output=self.predict_multimask, - repeat_image=False, - high_res_features=hr_feats, - ) + Hf, Wf = image_embeds_B.shape[-2], image_embeds_B.shape[-1] + target_mask_hw = (Hf * 4, Wf * 4) + + def _run_group(idxs, use_box, use_mask): + if not idxs: + return None + idxs_t = torch.tensor(idxs, device=self.device) + + # Slice features/images for this subgroup + img_emb = image_embeds_B[inst_img_ix[idxs_t]] # [n, C, Hf, Wf] + hr_sub = [lvl[inst_img_ix[idxs_t]] for lvl in hr_feats_B] + + # Slice raw prompts for this subgroup + pts_sub = pts_pad[idxs_t] # [n, P, 2] + lbl_sub = lbl_pad[idxs_t] # [n, P] + box_sub = boxes_full[idxs_t] if use_box else None # [n, 4] or None + mlog_sub = mask_logits_full[idxs_t] if use_mask else None # [n, Hm, Wm] or None + + # Get points + (normalized) boxes from the helper; do NOT pass mask here yet. + # This gives p_coords, p_labels, box_in in prompt space. 
+ mask_in_dummy, p_coords, p_labels, box_in = self.predictor._prep_prompts( + pts_sub, lbl_sub, box=box_sub, mask_logits=None, normalize_coords=True + ) - # 11) upsample to image res + # Build a base mask prompt from GT (if requested) at decoder’s expected spatial size + def _resize_to_target(m): + if m is None: + return None + m = m.to(self.device, dtype=torch.float32) + if m.dim() == 3: # [n, H, W] -> [n,1,H,W] + m = m.unsqueeze(1) + if m.shape[1] != 1: # keep single-channel + m = m[:, :1] + if m.shape[-2:] != target_mask_hw: + m = F.interpolate(m, target_mask_hw, mode="bilinear", align_corners=False) + return m + + base_mask_in = _resize_to_target(mlog_sub) # GT-derived mask prompt (logits in ±6) + + # Optional soft coarsening for prompts (make it a tad blurrier + smaller magnitude) + def _soften_coarsen(mlog): # mlog: [n,1,H,W] logits + if mlog is None: + return None + mlog = mlog.clamp(-6, 6) * 0.5 # ≈ ±3 + H, W = mlog.shape[-2], mlog.shape[-1] + # down → up for slight coarsening + scale = torch.empty(1, device=mlog.device).uniform_(0.5, 0.8).item() + h2, w2 = max(8, int(H * scale)), max(8, int(W * scale)) + with torch.no_grad(): + c = F.interpolate(mlog, (h2, w2), mode="bilinear", align_corners=False) + c = F.interpolate(c, (H, W), mode="bilinear", align_corners=False) + return 0.7 * c + 0.3 * mlog + + # ----------------------------- + # Two-pass refinement (only when use_mask=True): + # Pass-1: points/box only → predict mask (no grad), then use it as the mask prompt for Pass-2. + # If Pass-1 looks junky, fall back to a softened GT prompt. 
+ # ----------------------------- + mask_prompt_final = None + if use_mask: + use_pred_mask = torch.rand(1, device=self.device) < 0.5 # 50/50 pred vs GT prompt + + if use_pred_mask: + # PASS 1 (no mask prompt): encode prompts and decode once + sp1, dn1 = self.predictor.model.sam_prompt_encoder( + points=(p_coords, p_labels), boxes=box_in, masks=None + ) + with torch.no_grad(): # do not backprop through Pass-1 + low1, sc1, obj1, _ = self.predictor.model.sam_mask_decoder( + image_embeddings=img_emb, + image_pe=self.predictor.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sp1, + dense_prompt_embeddings=dn1, + multimask_output=True, # allow K proposals; we’ll pick best by score + repeat_image=False, + high_res_features=hr_sub, + ) + # pick best candidate by IoU score head + best_k = torch.argmax(sc1, dim=1) # [n] + n, k, h, w = low1.shape + gather_idx = best_k.view(n, 1, 1, 1).expand(n, 1, h, w) + pred_mask_in = low1.gather(1, gather_idx).detach() # [n,1,h,w], stop-grad + mask_prompt_final = _soften_coarsen(_resize_to_target(pred_mask_in)) + + # Fallback to GT prompt if predicted prompt is unusable (empty/near-zero) + if mask_prompt_final is None or (mask_prompt_final.abs().max() < 1e-4): + mask_prompt_final = _soften_coarsen(base_mask_in) + else: + # Use GT-derived prompt (softened) directly + mask_prompt_final = _soften_coarsen(base_mask_in) + # else: no mask prompt path, keep None + + # PASS 2: final decode (this is the ONLY output we train on) + sp_emb, dn_emb = self.predictor.model.sam_prompt_encoder( + points=(p_coords, p_labels), + boxes=box_in, + masks=mask_prompt_final, # None (no-mask path) or predicted/GT-mask prompt + ) + low, sc, obj, _ = self.predictor.model.sam_mask_decoder( + image_embeddings=img_emb, + image_pe=self.predictor.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sp_emb, + dense_prompt_embeddings=dn_emb, + multimask_output=self.predict_multimask, + repeat_image=False, + high_res_features=hr_sub, + ) + 
return (idxs, low, sc, obj) + + outs = [] + outs.append(_run_group(groups[0], use_box=False, use_mask=False)) # points + outs.append(_run_group(groups[1], use_box=True, use_mask=False)) # box+points + outs.append(_run_group(groups[2], use_box=False, use_mask=True)) # mask+points + outs.append(_run_group(groups[3], use_box=True, use_mask=True)) # mask+box+points + outs = [o for o in outs if o is not None] + + # (C) stitch back to original order + order = [] + low_list, score_list, obj_list = [], [], [] + for idxs, low, sc, obj in outs: + order.extend(idxs) + low_list.append(low); score_list.append(sc); obj_list.append(obj) + perm = torch.tensor(order, device=self.device).argsort() + + low_res_masks = torch.cat(low_list, dim=0)[perm] + prd_scores = torch.cat(score_list, dim=0)[perm] + obj_logits = torch.cat(obj_list, dim=0)[perm] + + # (D) upsample and finish target_sizes = [self.predictor._orig_hw[int(b)] for b in inst_img_ix] upsampled = [] for i in range(low_res_masks.shape[0]): @@ -289,10 +314,8 @@ def forward_step(self, batch): upsampled.append(up_i) prd_masks = torch.cat(upsampled, dim=0) - # 12) stack GT gt_masks = torch.stack(gt_all, dim=0).float() - - return prd_masks, prd_scores, gt_masks, inst_img_ix + return prd_masks, prd_scores, gt_masks, obj_logits, inst_img_ix @torch.no_grad() def validate_step(self, max_images=float('inf'), best_metric='ABIoU'): @@ -310,7 +333,7 @@ def validate_step(self, max_images=float('inf'), best_metric='ABIoU'): ) # --- local accumulators (tensor) --- - loss_keys = ["loss_total", "loss_iou", "loss_dice", "loss_mask"] + loss_keys = ["loss_total", "loss_iou", "loss_dice", "loss_mask", "loss_class"] losses_sum = {k: torch.tensor(0.0, device=self.device) for k in loss_keys} n_inst = torch.tensor(0.0, device=self.device) n_imgs = torch.tensor(0.0, device=self.device) @@ -325,11 +348,11 @@ def validate_step(self, max_images=float('inf'), best_metric='ABIoU'): out = self.forward_step(batch) if out[0] is None: continue # no instances 
in this batch - prd_masks, prd_scores, gt_masks = out[:3] + prd_masks, prd_scores, gt_masks, obj_logits = out[:4] local_n = torch.tensor(float(gt_masks.shape[0]), device=self.device) with self.autocast(): - batch_losses = self.loss_fn(prd_masks, prd_scores, gt_masks) + batch_losses = self.loss_fn(prd_masks, prd_scores, gt_masks, obj_logits) if best_metric == 'ABIoU': m = automask_metrics( @@ -375,6 +398,7 @@ def validate_step(self, max_images=float('inf'), best_metric='ABIoU'): "loss_iou": (losses_sum["loss_iou"] / inst_denom).item(), "loss_dice": (losses_sum["loss_dice"] / inst_denom).item(), "loss_mask": (losses_sum["loss_mask"] / inst_denom).item(), + "loss_class": (losses_sum["loss_class"] / inst_denom).item(), "num_images": int(img_denom), } out.update({k: (metrics_sum[k] / img_denom).item() for k in self.metric_keys}) @@ -391,13 +415,19 @@ def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 1e4): """ # Initialize the loss function + weight_dict = {"loss_mask": 1.0, "loss_dice": 20.0, "loss_iou": 1.0, "loss_class": 1.0} self.loss_fn = MultiMaskIoULoss( - weight_dict={"loss_mask": 20.0, "loss_dice": 1.0, "loss_iou": 1.0}, - focal_alpha=self.focal_alpha, - focal_gamma=self.focal_gamma, - supervise_all_iou=self.supervise_all_iou, - iou_use_l1_loss=self.iou_use_l1_loss + supervise_all_iou=True, # start focused; re-enable later if you want + all_iou_weight=0.1, + weight_dict=weight_dict, + pred_obj_scores=True, + focal_alpha=0.25, focal_gamma=2.0, + focal_alpha_obj=0.5, focal_gamma_obj=0.0, # objectness as plain-ish BCE + gate_by_objectness=False, ) + # in trainer.__init__ after building self.loss_fn + self.iou_head_weight_base = 1.0 # for later + self.loss_fn.iou_head_weight = 0.25 # start gentle # Initialize the metric keys if best_metric == 'ABIoU': @@ -408,11 +438,10 @@ def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 1e4): 'AR', 'R@10', 'R@50', 'R@100'] # Cosine scheduler w/Warmup ---- - # warmup_epochs = max(int(0.01 
* num_epochs), 1) - warmup_epochs = 5 - self.warmup_sched = LinearLR(self.optimizer, start_factor=0.1, total_iters=warmup_epochs) - self.cosine_sched = CosineAnnealingLR(self.optimizer, T_max=(num_epochs - warmup_epochs), eta_min=1e-6) - self.scheduler = SequentialLR(self.optimizer, [self.warmup_sched, self.cosine_sched], milestones=[warmup_epochs]) + self.warmup_epochs = 10 + self.warmup_sched = LinearLR(self.optimizer, start_factor=0.1, total_iters=self.warmup_epochs) + self.cosine_sched = CosineAnnealingLR(self.optimizer, T_max=(num_epochs - self.warmup_epochs), eta_min=1e-6) + self.scheduler = SequentialLR(self.optimizer, [self.warmup_sched, self.cosine_sched], milestones=[self.warmup_epochs]) # Progress bar only on rank 0 if self.is_global_zero: @@ -425,22 +454,27 @@ def train(self, num_epochs, best_metric = 'ABIoU', resample_frequency = 1e4): best_metric_value = float('-inf') # Main Loop for epoch in range(num_epochs): - # Train - # at start of each epoch - if hasattr(self.train_loader, "sampler") and hasattr(self.train_loader.sampler, "set_epoch"): - self.train_loader.sampler.set_epoch(epoch) - if (epoch+1) % resample_frequency == 0: - self.train_loader.dataset.resample_epoch() + # Scale the all_iou_weights based on epoch + if epoch < self.warmup_epochs: + self.loss_fn.supervise_all_iou = False + self.loss_fn.all_iou_weight = 0.0 + + else: + self.loss_fn.supervise_all_iou = True + self.loss_fn.all_iou_weight = min(0.05, 0.01*(epoch - self.warmup_epochs + 1)) + # gentle ramp for the IoU head + t = epoch - self.warmup_epochs + 1 + self.loss_fn.iou_head_weight = min(1.0, 0.25 + 0.15 * t) # e.g., 0.25→1.0 over ~6 epochs self.predictor.model.train() for batch in self.train_loader: out = self.forward_step(batch) if out[0] is None: continue - prd_masks, prd_scores, gt_masks = out[:3] + prd_masks, prd_scores, gt_masks, obj_logits = out[:4] with self.autocast(): - batch_losses = self.loss_fn(prd_masks, prd_scores, gt_masks) + batch_losses = self.loss_fn(prd_masks, 
prd_scores, gt_masks, obj_logits) if self.use_fabric: self.fabric.backward(batch_losses['loss_total']) else: @@ -511,6 +545,7 @@ def _reduce_losses(self, losses, num_elems: int = None): "loss_iou": "loss_iou", "loss_dice": "loss_dice", "loss_mask": "loss_mask", + "loss_class": "loss_class", "loss_total": "loss_total", } count = torch.tensor(float(num_elems if num_elems is not None else 1.0), device=self.device)