Highlights | Overview | Installation | Quick Start | Training | Evaluation | Datasets | Structure | Citation | Contact
SAMRI is an MRI-specialized adaptation of Meta AI's Segment Anything Model (SAM), designed for accurate and efficient segmentation across diverse MRI datasets.
By fine-tuning only the lightweight mask decoder on precomputed MRI embeddings, SAMRI achieves state-of-the-art Dice and boundary accuracy while drastically reducing computational cost.
The paper can be found HERE.
- Decoder-only fine-tuning: freeze SAM's heavy image encoder and prompt encoder.
- Two-stage pipeline: precompute embeddings, then fine-tune the decoder.
- 1.1 M MRI pairs from 36 datasets / 47 tasks across 10+ MRI protocols.
- 94% shorter training time and 96% fewer trainable parameters than full SAM retraining.
- Superior segmentation on small and medium structures, with strong zero-shot generalization.
- Supports box, point, and box + point prompts.
Fig.2 Overview of SAMRI: efficient two-stage training. Stage 1: precompute and store image embeddings with the frozen SAM encoder, removing redundant per-epoch computation. Stage 2: fine-tune only the lightweight mask decoder while keeping the image and prompt encoders frozen, dramatically reducing compute and memory cost.
SAMRI adapts SAM to the MRI domain by leveraging SAM's strong visual representations while tailoring the decoder to medical structures and contrasts.
The approach:
- Precomputes embeddings using SAM ViT-B encoder on 2D MRI slices.
- Fine-tunes only the mask decoder with a hybrid focal-Dice loss for domain adaptation.
This lightweight strategy allows SAMRI to train efficiently on a single GPU or multi-GPU clusters (e.g., H100 x 8), while maintaining robust accuracy across unseen datasets and imaging protocols.
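The actual loss implementation lives in `utils/losses.py`; the snippet below is only a hedged sketch of a hybrid focal + Dice objective to illustrate the idea (the weights and the `alpha`/`gamma` values are assumptions, not the repository's settings).

```python
import torch
import torch.nn.functional as F

def dice_loss(logits, target, eps=1e-6):
    """Soft Dice loss on sigmoid probabilities for binary masks of shape (B, 1, H, W)."""
    prob = torch.sigmoid(logits)
    inter = (prob * target).sum(dim=(-2, -1))
    union = prob.sum(dim=(-2, -1)) + target.sum(dim=(-2, -1))
    return 1.0 - (2.0 * inter + eps) / (union + eps)

def focal_loss(logits, target, alpha=0.25, gamma=2.0):
    """Binary focal loss; down-weights easy pixels so hard boundary pixels dominate."""
    bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    p_t = torch.exp(-bce)  # probability assigned to the true class
    return (alpha * (1.0 - p_t) ** gamma * bce).mean(dim=(-2, -1))

def hybrid_loss(logits, target, w_focal=1.0, w_dice=1.0):
    """Illustrative focal + Dice combination for mask-decoder fine-tuning."""
    return (w_focal * focal_loss(logits, target) + w_dice * dice_loss(logits, target)).mean()
```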
This section helps you go from zero to a runnable environment for SAMRI. It includes optional prerequisites, a reproducible Conda setup, and a brief explanation of how dependency installation works.
SAMRI requires Python >= 3.10 and PyTorch >= 2.2 (CUDA or ROCm recommended).
Use a package manager like Conda to isolate dependencies per project.
- Download Anaconda
- Download Miniconda (lightweight)
Verify Conda is available from the command line:

```bash
conda --version
```

If you already have a base environment:

```bash
conda create -n samri python=3.10 -y
conda activate samri
```

Please install the correct PyTorch build for your hardware (CUDA or ROCm recommended), then install SAMRI:

```bash
git clone https://github.com/wangzhaomxy/SAMRI.git
cd SAMRI
pip install .
```

Run a quick import test from the command line:

```bash
python -c "import torch, nibabel; print('SAMRI environment ready! Torch:', torch.__version__)"
```

If it prints without errors, your environment is correctly configured.
This project ships two entry points for running SAMRI on your data:
| Mode | File | Description |
|---|---|---|
| CLI | `inference.py` | Fast segmentation from the command line. |
| Notebook | `infer_step_by_step.ipynb` | Interactive visualization for detailed inspection. |
The pretrained SAMRI checkpoints can be downloaded HERE.
| Model | Checkpoint | Description |
|---|---|---|
| SAMRI (box) | samri_vitb_box.pth | SAMRI checkpoint with box prompt (easy to start with) |
| SAMRI (box) | samri_vitb_box_zero.pth | Zero-shot-enhanced SAMRI checkpoint with box prompt (easy to start with) |
| SAMRI (box+point) | samri_vitb_bp.pth | SAMRI checkpoint with box + point prompts (robust, higher accuracy) |
| SAMRI (box+point) | samri_vitb_bp_zero.pth | Zero-shot-enhanced SAMRI checkpoint with box + point prompts (robust, higher accuracy) |
| SAM ViT-B | sam_vit_b_01ec64.pth | SAM ViT-B checkpoint from GitHub. |
| SAM ViT-H | sam_vit_h_4b8939.pth | SAM ViT-H checkpoint from GitHub. |
| MedSAM | medsam_vit_b.pth | MedSAM checkpoint from GitHub. |
Once the checkpoints are downloaded, place them in the ./user_data/pretrained_ckpt/ directory.
Run SAM/SAMRI on a single NIfTI (.nii/.nii.gz) or standard image (.png/.jpg/.tif) and save the predicted mask.
Basic usage
```bash
python inference.py \
    --input ./user_data/Datasets/demoSample/example_img_1.nii.gz \
    --output ./user_data/Datasets/infer_output \
    --checkpoint ./user_data/pretrained_ckpt/samri_vitb_bp.pth \
    --model-type samri \
    --device cuda \
    --box X1 Y1 X2 Y2 \
    --point X Y \
    --no-png
```

For Apple Silicon (MPS):
```bash
python inference.py \
    --input ./user_data/Datasets/demoSample/example_img_1.nii.gz \
    --output ./user_data/Datasets/infer_output \
    --checkpoint ./user_data/pretrained_ckpt/samri_vitb_bp.pth \
    --model-type samri \
    --device mps \
    --box X1 Y1 X2 Y2 \
    --point X Y \
    --no-png
```

CLI arguments (from `inference.py`):
- `--input, -i` (required): path to `.nii`/`.nii.gz` or `.png`/`.jpg`/`.tif`
- `--output, -o` (required): output folder where results are written
- `--checkpoint, -c` (required): path to a SAM/SAMRI checkpoint (`.pth`)
- `--model-type` (default: `vit_b`): one of `vit_b | vit_h | samri` (`samri` maps to the ViT-B backbone)
- `--device` (default: `cuda`): e.g., `cuda`, `cpu` (or `mps` on Apple Silicon, if available)
- `--box X1 Y1 X2 Y2` (required): bounding box prompt (pixels)
- `--point X Y` (optional): foreground point prompt (pixels)
- `--no-png` (flag): if set, no PNG is saved and only the `.nii.gz` mask is written. Omit this flag to also save a PNG.
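If you have a reference mask and want box/point prompts in pixel coordinates, a small helper along these lines (not part of the CLI; purely illustrative) can derive them:

```python
import numpy as np

def prompts_from_mask(mask2d):
    """Derive a box (X1 Y1 X2 Y2) and a rough foreground point (X Y) from a binary 2D mask."""
    ys, xs = np.nonzero(mask2d)
    box = [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]
    # Centroid as a simple foreground point; may need adjusting for concave structures.
    point = [int(round(xs.mean())), int(round(ys.mean()))]
    return box, point
```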
Outputs
- `<case>_seg_.nii.gz`: predicted mask saved as NIfTI with shape `[1, H, W]`
- `<case>_seg_.png`: grayscale binary mask PNG (unless `--no-png` is set)
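To sanity-check a result programmatically, the predicted NIfTI can be opened with nibabel; the path below is an example based on the demo input above, not a guaranteed output name.

```python
import nibabel as nib
import numpy as np

# Example path; the actual name follows the <case>_seg_.nii.gz pattern described above.
seg = nib.load("./user_data/Datasets/infer_output/example_img_1_seg_.nii.gz")
mask = np.asarray(seg.dataobj)

print(mask.shape)       # expected (1, H, W)
print(np.unique(mask))  # foreground/background label values
```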
Example
```bash
python inference.py \
    --input ./user_data/Datasets/demoSample/example_img_1.nii.gz \
    --output ./user_data/Datasets/infer_output \
    --checkpoint ./user_data/pretrained_ckpt/samri_vitb_bp.pth \
    --model-type samri \
    --device cuda \
    --box 115 130 178 179 \
    --point 133 172
```

For Apple Silicon (MPS):
```bash
python inference.py \
    --input ./user_data/Datasets/demoSample/example_img_1.nii.gz \
    --output ./user_data/Datasets/infer_output \
    --checkpoint ./user_data/pretrained_ckpt/samri_vitb_bp.pth \
    --model-type samri \
    --device mps \
    --box 115 130 178 179 \
    --point 133 172
```
⚠️ Note:
- The input must be a 2D image.
- It is automatically normalized to 8-bit and converted to RGB to align with SAM's internal preprocessing.
- The expected NIfTI shape is (1, H, W) or (H, W); (H, W, 1) is also supported via automatic squeezing.
- Standard image inputs may have any of the following shapes: H×W, H×W×1, H×W×3, or H×W×4.
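The CLI performs this preparation internally; the following is only a rough illustration of the same idea for a NIfTI slice (min-max normalization to 8-bit and RGB stacking), not the repository's exact code.

```python
import numpy as np
import nibabel as nib

def load_2d_slice(path):
    """Load a 2D MRI slice and bring it to the uint8 RGB layout SAM expects (illustrative only)."""
    arr = np.asarray(nib.load(path).dataobj).astype(np.float32)
    arr = np.squeeze(arr)                        # (1, H, W) / (H, W, 1) -> (H, W)
    if arr.ndim != 2:
        raise ValueError(f"Expected a single 2D slice, got shape {arr.shape}")
    lo, hi = arr.min(), arr.max()
    arr = (arr - lo) / (hi - lo + 1e-8) * 255.0  # min-max normalize to the 8-bit range
    img8 = arr.astype(np.uint8)
    return np.stack([img8] * 3, axis=-1)         # (H, W, 3) RGB
```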
Use the notebook to experiment with prompts and visualize each stage.
Open ./infer_step_by_step.ipynb and set the cell parameters:
```python
# --- User configuration ---
INPUT_PATH = "/path/to/your/input.nii.gz"       # or .png/.jpg
OUTPUT_DIR = "./Notebook_Visualization"
CHECKPOINT = "./checkpoints/samri_decoder.pth"  # SAM / SAMRI checkpoint
MODEL_TYPE = "samri"                            # 'vit_b' | 'vit_h' | 'samri'
DEVICE = "cuda"                                 # 'cuda' | 'cpu' | 'mps'

# Optional prompts (pixel coords)
BOX = [30, 40, 200, 220]                        # or None
POINT = [120, 140]                              # or None
SAVE_PNG = True                                 # also write PNG next to the NIfTI
```

Then run the cells to:
- Load & normalize the input (NIfTI or image)
- Configure optional box/point prompts
- Run SAMRI inference
- Save: `<name>_seg_.nii.gz` (+ optional `<name>_seg_.png`)
- Display publication-friendly overlays/contours inside the notebook
The notebook uses the same image preparation and I/O utilities as the CLI, ensuring identical masks for matching inputs and prompts.
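If you want to script the same steps outside the notebook, the vanilla predictor API of the bundled `segment_anything` package looks roughly like the sketch below; the checkpoint path, registry key, and dummy image are assumptions, and the notebook's own loading utilities remain the authoritative route for SAMRI checkpoints.

```python
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Assumed paths/keys for illustration only.
sam = sam_model_registry["vit_b"](checkpoint="./user_data/pretrained_ckpt/sam_vit_b_01ec64.pth")
sam.to("cuda")
predictor = SamPredictor(sam)

image_rgb = np.zeros((256, 256, 3), dtype=np.uint8)  # replace with your prepared (H, W, 3) uint8 slice
predictor.set_image(image_rgb)

masks, scores, _ = predictor.predict(
    box=np.array([115, 130, 178, 179]),   # X1 Y1 X2 Y2 box prompt (pixels)
    point_coords=np.array([[133, 172]]),  # optional foreground point (pixels)
    point_labels=np.array([1]),           # 1 = foreground
    multimask_output=False,
)
print(masks.shape, scores)
```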
This is the simplest approach and requires no command-line arguments.
1. Open a terminal and make sure you are inside the SAMRI folder:

   ```bash
   cd /path/to/SAMRI
   ```

2. Start Jupyter Notebook:

   ```bash
   jupyter notebook
   ```

3. In the browser window:
   - Navigate to the SAMRI folder (if needed)
   - Click `GUI_jupyter.ipynb`

4. Run all cells

This will launch the complete SAMRI GUI, including:
- Multi-view MRI display
- Paint / Erase / Box / Point tools
- SAMRI checkpoint loader
- Multi-candidate mask selector
- Thumbnail previews + score panel
No additional parameters are needed.
- Ensure all dependencies are installed correctly.
- The GUI is compatible with all SAMRI models and the SAM ViT-B model. MedSAM is NOT supported because it applies a different image-preprocessing pipeline from SAM and SAMRI.
- Use Jupyter Notebook, not JupyterLab (unless compatible renderers/extensions are installed).
This section covers end-to-end training of SAMRI's decoder on precomputed SAM embeddings. The workflow is lightweight:
1) Prepare data → 2) Precompute embeddings → 3) Train decoder.
SAMRI freezes SAM's image encoder and fine-tunes only the mask decoder using a Dice + Focal loss.
Download & Organize Raw MRI Data: Download the raw MRI datasets and organize them according to the specifications outlined in the RawData section. Verify that all files are correctly structured and complete before initiating the preprocessing step.
Click to see the Raw Datasets structure

```
Datasets/
├── ACDC/
├── Brain_Tumor_Dataset_Figshare/
├── Brain-TR-GammaKnife-processed/
├── CC-Tumor-Heterogeneity/
├── CHAOS/
├── HipMRI/
├── ISLES-2022/
├── Meningioma-SEG-CLASS/
├── MSD/
├── MSK_knee/
├── MSK_shoulder/
├── NCI-ISBI/
├── Npz_dataset/
│   ├── AMOSMR/
│   ├── BraTS_FLAIR/
│   ├── BraTS_T1/
│   ├── BraTS_T1CE/
│   ├── CervicalCancer/
│   ├── crossmoda/
│   ├── Heart/
│   ├── ISLES2022_ADC/
│   ├── ISLES2022_DWI/
│   ├── ProstateADC/
│   ├── ProstateT2/
│   ├── QIN-PROSTATE-Lesion/
│   ├── QIN-PROSTATE-Prostate/
│   ├── SpineMR/
│   ├── totalseg_mr/
│   ├── WMH_FLAIR/
│   └── WMH_T1/
├── OAI_imorphics_dess_sag/
├── OAI_imorphics_flash_cor/
├── OAI_imorphics_tse_sag/
├── OAIAKOA/
├── Picai/
├── PROMISE/
├── QIN-PROSTATE-Repeatability/
├── QUBIQ/
├── Spine/
└── ZIB_OAI/
```
Link is being prepared...

Run the Preprocessing Script: Use the following command to preprocess and save the datasets:
```bash
python -m image_processing.data_processing_code.data_processing \
    --dataset-path /path/to/your/target-dataset \
    --save-path /path/to/your/target-save-directory
```
⚠️ Note: The raw data may be periodically updated by the dataset authors. If errors occur, please modify the corresponding scripts under `./image_processing/data_processing_code/` to maintain compatibility.
Preparing Your Own Data: To apply SAMRI to your own MRI data, follow the recommended workflow below:
- Patient-wise splitting: Divide the dataset into training, validation, and testing subsets.
- Slice generation: Convert each 3D MRI volume into a series of 2D slices.
- Quality filtering: Retain only slices whose mask contains more than 10 foreground pixels (to filter out noisy, near-empty masks).
- Noise removal: Manually inspect and remove noisy or corrupted image-mask pairs (e.g., thin lines or artifacts).
- Data organization: Save the cleaned data in a well-structured directory format (e.g., training/, validation/, testing/).
Organize datasets into separate folders with patient-wise splits for training, validation, and testing. Place each image and its corresponding mask within the same directory. Examples:
```
./user_data/Datasets/SAMRI_train_test
├── dataset_A/
│   ├── training/
│   │   ├── example1*_img_*.nii.gz
│   │   ├── example1*_seg_*.nii.gz   # matching file names (binary/label masks)
│   │   └── ...
│   ├── validation/                  # .nii.gz
│   │   └── ...
│   └── testing/
│       └── ...
├── dataset_B/
│   ├── training/
│   ├── validation/                  # .nii.gz
│   └── testing/
└── ...
```
- Masks should align with images in shape and orientation.
- For 3D NIfTI, training is typically on 2D slices.
- Image and mask files should sit in the same folder and be distinguished by the keys "_img_" (images) and "_seg_" (masks); the remaining parts of the filenames should be identical so that images and masks pair up in the same order after sorting.
- Both the image and the mask should have shape 1 × H × W.
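As a concrete illustration of the slice-generation and quality-filtering steps above and of the "_img_"/"_seg_" naming convention, a hedged slicing helper might look like the following; the slicing axis, dtype choices, and affine handling are assumptions you should adapt to your data.

```python
import os
import numpy as np
import nibabel as nib

def volume_to_slices(img_path, seg_path, out_dir, case_id, min_mask_pixels=10):
    """Split a 3D image/mask volume into 2D .nii.gz pairs, keeping slices with >10 mask pixels."""
    os.makedirs(out_dir, exist_ok=True)
    img = np.asarray(nib.load(img_path).dataobj)
    seg = np.asarray(nib.load(seg_path).dataobj)
    for z in range(img.shape[-1]):                       # assume slices lie along the last axis
        mask2d = seg[..., z]
        if np.count_nonzero(mask2d) <= min_mask_pixels:  # quality filter
            continue
        img2d = img[..., z][None]                        # (1, H, W), matching the expected shape
        seg2d = mask2d[None]
        affine = np.eye(4)                               # placeholder affine; keep the original if needed
        nib.save(nib.Nifti1Image(img2d.astype(np.float32), affine),
                 os.path.join(out_dir, f"{case_id}_img_{z:03d}.nii.gz"))
        nib.save(nib.Nifti1Image(seg2d.astype(np.uint8), affine),
                 os.path.join(out_dir, f"{case_id}_seg_{z:03d}.nii.gz"))
```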
Use the SAM ViT-B model to compute and cache image embeddings, thereby reducing training time and memory consumption.
```bash
python -m preprocess.precompute_embeddings \
    --base-path ./user_data \
    --dataset-path ./user_data/Datasets/SAMRI_train_test/ \
    --img-sub-path train/ \
    --save-path ./user_data/Datasets/Embedding_train/ \
    --checkpoint ./user_data/pretrained_ckpt/sam_vit_b_01ec64.pth \
    --device cuda
```

Key args:
- `--base-path`: root folder of the user data.
- `--dataset-path`: dataset directory.
- `--img-sub-path`: dataset subfolder; choose from "train", "validation", and "test".
- `--save-path`: embedding save directory.
- `--checkpoint`: path to the SAM ViT-B checkpoint.
- `--device`: computation device; choose from "cuda", "cpu", and "mps".
The computed embeddings are saved in a .npz file containing the following keys:
- img: the image embedding
- mask: the corresponding segmentation mask
- ori_size: the original height and width of the image and mask.
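To verify a cached embedding after this step, you can open one of the saved files; the filename below is illustrative, and the keys are those listed above.

```python
import numpy as np

# Inspect one cached embedding (example filename).
data = np.load("./user_data/Datasets/Embedding_train/example_0001.npz")
emb = data["img"]            # SAM ViT-B image embedding
mask = data["mask"]          # corresponding segmentation mask
ori_size = data["ori_size"]  # original (H, W) of the image/mask

print(emb.shape, mask.shape, ori_size)
```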
SAMRI can be trained on a single commercially available GPU. The following example command illustrates this setup. Some HPC systems also provide interactive GPU terminals for direct command-line execution.
```bash
python train_single_gpu.py \
    --model_type samri \
    --batch_size 48 \
    --data_path ./user_data \
    --model_save_path ./user_data/Model_save \
    --num-epochs 120 \
    --device cuda \
    --save-every 2 \
    --prompts mixed
```

Some HPC systems provide command-line access for multi-GPU training. The following command can be used in such cases and can also be included in a SLURM script for batch execution.
```bash
python train_multi_gpus.py \
    --model_type samri \
    --batch_size 48 \
    --data_path ./user_data \
    --model_save_path ./user_data/Model_save \
    --num-epochs 120 \
    --save-every 2 \
    --prompts mixed
```

Common args for training scripts:
- `--model_type samri`: the training model type.
- `--batch_size`: per-process batch size (effective batch size = batch_size × world_size). As a guide, MI300X (192 GB) supports about 1024 and A100/H100 (80 GB) about 512; lower the batch size if OOM occurs.
- `--data_path ./user_data`: the training embedding folder path.
- `--model_save_path ./user_data/Model_save`: where to write checkpoints (`.pth`).
- `--num-epochs`: number of training epochs.
- `--device`: the training device; choose from "cuda" and "mps".
- `--save-every`: save a checkpoint every N epochs.
- `--prompts mixed`: training prompts; choose from "point", "bbox", and "mixed", where "mixed" means point + bbox prompts.
SLURM scripts are commonly used for job submission in HPC environments. The provided example can be found at ./train_multi_gpus_mi300.sh. Modify the configuration as needed to suit your specific HPC setup.
Examples:
```bash
#!/bin/bash --login
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=64
#SBATCH --mem=1T
#SBATCH --job-name=SAMRI
#SBATCH --time=7-00:00:00
#SBATCH --partition=gpu_rocm
#SBATCH --gres=gpu:mi300x:8
#SBATCH --account=xxxxx                        # Use your account if available.
#SBATCH --qos=sdf
#SBATCH -o /home/Documents/slurm-%j.output     # The path to save output logs.
#SBATCH -e /home/Documents/slurm-%j.error      # The path to save processing logs.
#SBATCH --mail-type=ALL
#SBATCH --mail-user=your_email@email.com

module load anaconda3
source $EBROOTANACONDA3/etc/profile.d/conda.sh
conda activate samri-mi300

# Dynamically assign port from job ID to avoid collisions
export MASTER_ADDR=localhost
export MASTER_PORT=$((26000 + RANDOM % 1000))  # Pick a port between 26000 and 26999

python train_multi_gpus.py
```

Troubleshooting:
- CUDA/ROCm OOM: lower `--batch_size`; reduce `num_workers`.
- Slow data loading: set `--num_workers 8..12`.
- Validation mismatch: confirm the same preprocessing/normalization as training.
- PyTorch/CUDA: install a build matching your CUDA version
- ROCm (AMD MI300X/MI210): use ROCm PyTorch wheels; setting NCCL/RCCL environment flags may help
- Apple Silicon (MPS): training is possible, but performance is limited compared to CUDA/ROCm
This section outlines the procedures for validating, testing, and visualizing model performance across the validation and test datasets. Models are evaluated using the following metrics:
- Dice Similarity Coefficient (DSC)
- Hausdorff Distance (HD)
- Mean Surface Distance (MSD)
Dedicated evaluation scripts are also provided for SAM, SAMRI, and MedSAM models.
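The metric implementations used by the evaluation scripts live in the repository itself; for reference only, a generic sketch of DSC and a symmetric Hausdorff distance (computed here over all foreground pixels rather than extracted boundaries, which the scripts may handle differently) could look like this:

```python
import numpy as np
from scipy.spatial.distance import directed_hausdorff

def dice_coefficient(pred, gt, eps=1e-8):
    """Dice Similarity Coefficient between two binary masks."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    return (2.0 * inter + eps) / (pred.sum() + gt.sum() + eps)

def hausdorff_distance(pred, gt):
    """Symmetric Hausdorff distance between the foreground point sets (in pixels)."""
    p = np.argwhere(pred.astype(bool))
    g = np.argwhere(gt.astype(bool))
    return max(directed_hausdorff(p, g)[0], directed_hausdorff(g, p)[0])
```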
This step evaluates model performance on precomputed embeddings rather than raw images.
It is efficient for internal validation because embeddings are already generated during preprocessing.
Script: ./evaluation/val_in_batch.py
```bash
python evaluation/val_in_batch.py \
    --val-emb-path /path/to/val/embeddings/ \
    --ckpt-path /path/to/checkpoint_directory/ \
    --prompts mixed \
    --device cuda \
    --batch-size 64
```
⚠️ Notes:
- The script loads embeddings directly from `.npz` files and runs batch evaluation.
- This avoids redundant image encoding and greatly speeds up the validation process.
- Use this to measure training progress or perform hyperparameter tuning.
- The results will be saved in a CSV file under the checkpoint directory.
This evaluates the model directly on the test images (not precomputed embeddings).
Two common use cases are supported:
Use a specific checkpoint file for testing:
```bash
python evaluation/test_vis.py \
    --test-image-path /path/to/test/images/ \
    --ckpt-path /path/to/checkpoint.pth \
    --save-path /path/to/save/results/ \
    --device cuda \
    --model-type samri
```

Automatically evaluate all `.pth` files in a directory:
```bash
python evaluation/test_vis.py \
    --test-image-path /path/to/test/images/ \
    --ckpt-path /path/to/checkpoint_directory/ \
    --save-path /path/to/save/results/ \
    --device cuda \
    --model-type samri
```

Features:
- Supports both SAM and SAMRI models.
- The script automatically detects single-file or multi-checkpoint folders.
- Evaluates each checkpoint and saves detailed metrics and predictions in a Python pickle file.
⚠️ Note: MedSAM uses a distinct preprocessing and inference pipeline (see below).
Two dedicated scripts are provided to ensure MedSAM compatibility.
Runs inference using the original MedSAM architecture (from its official repository)
with added dataset loading and result-saving features.
```bash
python evaluation/test_medsam.py \
    --test-image-path /path/to/test/images/ \
    --ckpt-path /path/to/checkpoint.pth \
    --save-path /path/to/save/results/ \
    --device cuda
```

- Each case is saved as an `.npz` file containing both the ground-truth mask and the predicted mask.
- Useful for comparing outputs across architectures.
Processes the .npz results produced above and computes evaluation metrics:
```bash
python evaluation/test_medsam_eval.py \
    --medsam-infer-path /path/to/medsam/inference/results/ \
    --save-path /path/to/save/evaluation/results/
```

- Aggregates and reports Dice, IoU, and boundary metrics.
- Produces results in the same standardized format as SAMRI evaluations.
Use the provided Jupyter notebook to visualize and compare results interactively:
Notebook: /evaluation/result_visualize_and_evaluate.ipynb
Open it and set the directories where your .npz result files were saved:
```python
result_root = "/the/directory/of/the/evaluation/results/Eval_results/"
```

You can:
- Compare performance between SAM, SAMRI, and MedSAM
- Generate summary plots (Dice boxplots, etc.)
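Outside the notebook, a quick comparison plot can be produced with matplotlib; the scores below are random placeholders standing in for the per-case Dice values you would collect from the saved results, not actual measurements.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
# Placeholder per-case Dice scores; substitute the values collected from your saved results.
dice_scores = {
    "SAM": rng.uniform(0.6, 0.8, size=30),
    "MedSAM": rng.uniform(0.65, 0.85, size=30),
    "SAMRI": rng.uniform(0.75, 0.95, size=30),
}

fig, ax = plt.subplots(figsize=(5, 4))
ax.boxplot(list(dice_scores.values()), labels=list(dice_scores.keys()))
ax.set_ylabel("Dice Similarity Coefficient")
ax.set_title("Per-case Dice by model (placeholder data)")
plt.tight_layout()
plt.show()
```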
SAMRI is trained on a curated set of 1.1 million MRI image-mask pairs from 36 public datasets (47 segmentation tasks) spanning more than 10 MRI sequences (T1, T2, FLAIR, DESS, TSE, etc.).
| Category | Example Datasets | Approx. Pairs |
|---|---|---|
| Brain | BraTS, ISLES, | 440 K |
| Knee | MSK_T2, OAI, | 286 K |
| Abdomen | AMOSMR, HipMRI | 176 K |
| Total Body | Totalseg MR | 143 K |
| Others | Prostate, MSD_kidney | 55 K |
Detailed dataset breakdowns are provided in Table S1 (Supplementary) in the paper.
Figure 4. Datasets: anatomical coverage of the 1.1 million MRI image-mask pairs used to train SAMRI, summarized by body region (Brain 40%, Knee 26%, Abdomen 16%, Vertebrae 2.6%, Shoulder 0.5%, Thorax 0.2%, and a Whole-body/"Total Body" set 13.3%).
The SAMRI repository is organized into modular components for preprocessing, training, evaluation, and utility functions.
Below is an overview of the folder hierarchy and their main purposes:
```
SAMRI/
├── evaluation/                # Evaluation and visualization scripts
├── image_processing/          # Data preprocessing & embedding generation
├── segment_anything/          # SAM backbone integration
├── utils/                     # Dataloaders, losses, utilities
├── inference.py               # CLI inference entry
├── infer_step_by_step.ipynb
├── train_single_gpu.py
├── train_multi_gpus.py
├── train_multi_gpus_mi300.sh
├── model.py
├── setup.py
└── README.md
```

Click to see the Repository Structure Details
```
SAMRI/
├── evaluation/                              # Model evaluation and visualization scripts
│   ├── MedSAM-main/                         # External MedSAM main code
│   ├── result_visualize_and_evaluate.ipynb  # Visualization and comparative analysis notebook
│   ├── test_medsam.py                       # Run MedSAM inference and save predictions
│   ├── test_medsam_eval.py                  # Evaluate MedSAM inference results (.npz files)
│   ├── test_vis.py                          # Evaluate SAM/SAMRI models on test datasets
│   ├── val_in_batch.py                      # Batch validation using precomputed embeddings
│   └── utils.py                             # Shared helpers for result_visualize_and_evaluate.ipynb
│
├── image_processing/                        # Data preprocessing and embedding generation
│   ├── data_processing_code/                # Individual dataset preprocessing scripts
│   └── process_embedding.py                 # Generate image embeddings for SAMRI
│
├── segment_anything/                        # SAM model integration
│   └── ...                                  # (Meta AI SAM model components)
│
├── user_data/                               # (Optional) Placeholder for user data or experiments
│
├── utils/                                   # Core utilities shared across training/inference
│   ├── dataloader.py                        # Dataset loading and management
│   ├── losses.py                            # Custom loss functions (e.g., Dice + Focal)
│   ├── utils.py                             # Configuration, device setup, and helper methods
│   └── visual.py                            # Visualization utilities
│
├── infer_step_by_step.ipynb                 # Interactive notebook for step-by-step inference
├── inference.py                             # Command-line inference script
├── model.py                                 # SAMRI model definition
├── sarmi_gui(BugWarning).py                 # GUI version (experimental)
│
├── train_single_gpu.py                      # Training script for single-GPU setups
├── train_multi_gpus.py                      # Training script for multi-GPU (DDP)
├── train_multi_gpus_mi300.sh                # SLURM submission script for MI300X cluster
│
├── setup.py                                 # Installation and environment setup
├── LICENSE                                  # License file
└── README.md                                # Main documentation file
```
If you use SAMRI in your research, please cite:
```bibtex
@misc{wang2025samrisegmentmodelmri,
      title={SAMRI: Segment Anything Model for MRI},
      author={Zhao Wang and Wei Dai and Thuy Thanh Dao and Steffen Bollmann and Hongfu Sun and Craig Engstrom and Shekhar S. Chandra},
      year={2025},
      eprint={2510.26635},
      archivePrefix={arXiv},
      primaryClass={eess.IV},
      url={https://arxiv.org/abs/2510.26635},
}
```

This repository is released under the Apache 2.0 License.
See the LICENSE file for details.
Developed at The University of Queensland (UQ),
School of Electrical Engineering and Computer Science (EECS).
Special thanks to the Bunya HPC Team for infrastructure support.
Built upon Meta AI's Segment Anything Model (SAM), and inspired by the broader community efforts to adapt SAM to medical imaging.
We also gratefully acknowledge the MedSAM team for pioneering open-source adaptations of SAM for medical images and for releasing code/weights that served as an important baseline and point of comparison.
We thank open-source contributors and the MRI research community for dataset availability.
Zhao Wang
School of Electrical Engineering and Computer Science (EECS)
The University of Queensland, Australia
📧 zhao.wang1@uq.edu.au
Shekhar "Shakes" Chandra
ARC Future Fellow & Senior Lecturer
School of Electrical Engineering and Computer Science (EECS)
The University of Queensland, Australia
📧 shekhar.chandra@uq.edu.au


