diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..0d79e9d4e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+recognition/43711451_HipMRI2D_AttentionUNET/__pycache__/
\ No newline at end of file
diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/README.MD b/recognition/43711451_HipMRI2D_AttentionUNET/README.MD
new file mode 100644
index 000000000..603bbd4c3
--- /dev/null
+++ b/recognition/43711451_HipMRI2D_AttentionUNET/README.MD
@@ -0,0 +1,92 @@
+# Title
+ 2D HipMRI Dataset Segmented using an Attention U-Net
+
+# Problem Description
+ Segment the HipMRI Study on Prostate Cancer dataset (see Appendix for link) using the processed 2D slices with a 2D Attention U-Net, with the prostate label reaching a minimum Dice similarity coefficient of 0.75 on the test set. The data is supplied in the NIfTI file format, so a NIfTI loader (adapted from the provided sample code) is required.
+
+# Algorithm Description
+## Data Preparation
+ Data was pre-sorted into training, validation and testing datasets, each with matching segmentation masks.
+ Each 2D slice was wrapped as a single-slice 3D volume so that all inputs could be resampled to a common template, and the masks were one-hot encoded because they are multi-class rather than binary.
+ After wrapping, the images were normalised, shuffled (for training) and placed into Datasets and DataLoaders for training and validation.
+
+## Model Architecture
+### Filters
+ The encoder uses filter counts of [64, 128, 256, 512, 1024]; the decoder mirrors this with [1024, 512, 256, 128, 64].
+### Downsampling (Encoder)
+ The encoder passes the input through a double-convolution block at each filter size, max-pooling between blocks and keeping each block's output for the skip connections.
+
+### Upsampling (Decoder)
+ The decoder mirrors the encoder: it upsamples from the bottleneck and, at each step, receives the corresponding skip connection before applying another double convolution.
+
+### Bottleneck
+ The bottleneck applies the double convolution with 3x3 kernels at the lowest resolution.
+ Its output is then upsampled through the decoder.
+
+### Attention Gates
+ Attention gates are what turn this into an Attention U-Net: each skip connection is re-weighted by a gate that combines semantic context from the decoder (the gating signal) with spatial detail from the encoder, before being concatenated alongside the regular connection.
+
+## Model Training
+ The model was trained on all training samples and validated on the validation samples.
+ A fixed random seed was also applied so that runs are reproducible.
+
+## Model Performance
+ The model was trained and saved using a batch size of 16 and initially 25 epochs on the full training and validation sets.
+ The full 25-epoch run completed, but the save location was misconfigured so those weights were not kept; I would have liked to refine that run.
+ Training was run on an Nvidia RTX 3060 Ti with 8GB of VRAM.
+ After roughly 20 epochs the model was overfitting; this could have been handled better by experimenting with smaller subsets and more epochs.
+
+## Model Testing
+ After training, the model was tested using the predict function, which produces a predicted mask (and a visualisation) for a chosen test slice; a sketch of scoring the whole test set follows below.
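+
+ There is no standalone evaluation script in this folder, so the snippet below is only a minimal sketch (not committed code) of how the class-averaged Dice coefficient could be checked against the 0.75 target using the Dice, LoadData and AttentionUNet helpers defined here; the checkpoint filename is the one under saved_models/ and it assumes the segmentation masks for the test split are available, so swap in your own paths as needed.
+
+```python
+import torch
+from dataset import LoadData
+from modules import AttentionUNet
+from utils import Dice
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Load the trained weights (adjust the checkpoint path to your own run)
+model = AttentionUNet(num_channels=1, num_classes=6).to(device)
+model.load_state_dict(torch.load("saved_models/full_set_5_epochs.pth", map_location=device))
+model.eval()
+
+# first_n=0 loads the whole test split; shuffle=False keeps the slice order stable
+test_loader = LoadData(dataset="test", first_n=0, batch_size=16, shuffle=False)
+criterion = Dice()  # forward() returns 1 - (class-averaged Dice coefficient)
+
+total_dice, batches = 0.0, 0
+with torch.no_grad():
+    for images, masks in test_loader:
+        loss = criterion(model(images.to(device)), masks.to(device))
+        total_dice += 1.0 - loss.item()  # convert Dice loss back into a Dice coefficient
+        batches += 1
+
+print(f"Mean Dice over the test set: {total_dice / batches:.4f} (target >= 0.75)")
+```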
+# How it Works
+## File Structure
+.
+├── recognition
+│   └── 43711451_HipMRI2D_AttentionUNET
+│       ├── mask_output
+│       │   ├── pred_mask_case_040_week_0_slice_0.nii.gz
+│       │   ├── pred_mask_case_040_week_0_slice_0.nii.png
+│       │   ├── pred_mask_case_040_week_1_slice_0.nii.gz
+│       │   ├── pred_mask_case_040_week_1_slice_0.nii.png
+│       │   ├── pred_mask_case_040_week_2_slice_0.nii.gz
+│       │   └── pred_mask_case_040_week_2_slice_0.nii.png
+│       ├── saved_models
+│       │   └── full_set_5_epochs.pth
+│       ├── dataset.py
+│       ├── modules.py
+│       ├── predict.py
+│       ├── README.MD
+│       ├── train.py
+│       └── utils.py
+├── LICENSE
+└── README.md
+
+## Requirements & Dependencies
+ PyTorch (for tensors, network layers and model training)
+ NumPy (for numerical operations)
+ Matplotlib (for plotting)
+ NiBabel (for neuroimaging data handling)
+ NiLearn (for image resampling)
+ Tqdm (for progress bars)
+ Pathlib (for filesystem path manipulation)
+ Random (for random seeding)
+
+## Future Improvements
+ I would like to adapt this towards 3D datasets and to visualise the results in nicer, more interesting ways.
+ I would also like to clean up the hard-coded parameters and centralise the configuration so the project is easier to run.
+ I would have liked to extend the predict script to compare the generated masks against the provided ground-truth masks more clearly, for example by highlighting the regions that were segmented correctly.
+ A lot of the predict script was exciting to build, but I did not expand on it as much as I would have liked; the same goes for this documentation.
+
+## Usage & Reproduction Steps
+ For training:
+ python3 train.py
+ For predictions and visualisations:
+ python3 predict.py
+
+ Hyperparameters are defined at the top of both scripts so each run can be customised; the key values are excerpted below.
+ Set the dataset location in the path variable in dataset.py as well.
+ Subsets can be configured using the SUBSET parameter to load less data (0 = all of the data).
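+
+ As a quick reference, these are the committed defaults from the tops of the scripts (the absolute Windows paths are examples from my machine and need changing for yours; the "..." below just shortens them):
+
+```python
+# dataset.py -- dataset location and loader settings
+path = "D:/keras_slices_data/keras_slices_"  # prefix; "train", "validate", "test" (or "seg_" + split for masks) is appended
+BATCH_SIZE = 16                              # sized for an 8GB GPU
+SUBSET = 25                                  # 0 = load the full split, N = load only the first N slices
+
+# train.py -- training settings
+SUBSET = 0
+NUM_CLASSES = 6
+NUM_EPOCHS = 10
+LEARNING_RATE = 1e-4
+SAVED_MODEL_PATH = ".../saved_models/full_set_x_epochs.pth"
+
+# predict.py -- inference settings
+SAVED_MODEL_PATH = ".../saved_models/full_set_5_epochs.pth"
+TEMPLATE_IMG_PATH = "D:/keras_slices_data/keras_slices_train/case_004_week_0_slice_0.nii.gz"
+SELECTED_SLICE = "case_040_week_3_slice_0.nii.gz"
+INPUT_IMG_PATH = f"D:/keras_slices_data/keras_slices_test/{SELECTED_SLICE}"
+```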
+# Visualisations
+ Output prediction visualisations are saved in the recognition/43711451_HipMRI2D_AttentionUNET/mask_output folder.
diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/__pycache__/dataset.cpython-311.pyc b/recognition/43711451_HipMRI2D_AttentionUNET/__pycache__/dataset.cpython-311.pyc
new file mode 100644
index 000000000..0377accc5
Binary files /dev/null and b/recognition/43711451_HipMRI2D_AttentionUNET/__pycache__/dataset.cpython-311.pyc differ
diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/__pycache__/modules.cpython-311.pyc b/recognition/43711451_HipMRI2D_AttentionUNET/__pycache__/modules.cpython-311.pyc
new file mode 100644
index 000000000..ffe1d7cf4
Binary files /dev/null and b/recognition/43711451_HipMRI2D_AttentionUNET/__pycache__/modules.cpython-311.pyc differ
diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/__pycache__/train.cpython-311.pyc b/recognition/43711451_HipMRI2D_AttentionUNET/__pycache__/train.cpython-311.pyc
new file mode 100644
index 000000000..82043a400
Binary files /dev/null and b/recognition/43711451_HipMRI2D_AttentionUNET/__pycache__/train.cpython-311.pyc differ
diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/__pycache__/utils.cpython-311.pyc b/recognition/43711451_HipMRI2D_AttentionUNET/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 000000000..5e65f091d
Binary files /dev/null and b/recognition/43711451_HipMRI2D_AttentionUNET/__pycache__/utils.cpython-311.pyc differ
diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/dataset.py b/recognition/43711451_HipMRI2D_AttentionUNET/dataset.py
new file mode 100644
index 000000000..b47f6441f
--- /dev/null
+++ b/recognition/43711451_HipMRI2D_AttentionUNET/dataset.py
@@ -0,0 +1,274 @@
+# recognition\43711451_HipMRI2D_AttentionUNET\dataset.py
+"""
+Contains the data loader and preprocessing for the HipMRI 2D Slice Dataset to be used by the model
+"""
+
+import numpy as np
+import nibabel as nib
+from nibabel import Nifti1Image
+from nilearn.image import resample_to_img
+from tqdm import tqdm
+from pathlib import Path
+import torch
+from torch.utils.data import Dataset, DataLoader
+
+__author__ = "Cleodora Kizmann"
+__copyright__ = "Copyright 2025, Cleodora Kizmann"
+__credits__ = ["Cleodora Kizmann"]
+__license__ = "Apache License 2.0"
+__version__ = "1.0.1"
+__maintainer__ = "Cleodora Kizmann"
+__email__ = "cleodora.kizmann@student.uq.edu.au"
+__status__ = "Prototype"
+
+# Dataset path
+path = "D:/keras_slices_data/keras_slices_" # Adjust this path as needed
+
+# Hyperparameters
+BATCH_SIZE = 16 # I got 8GB VRAM on my GPU so I might be pushing this a little
+SUBSET = 25
+
+def to_channels(arr: np.ndarray, num_classes: int, dtype = np.uint8)-> np.ndarray:
+    """
+    Converts an integer label array into a one-hot encoded array.
+
+    Args:
+        arr: The input 2D mask array (H, W).
+        num_classes: The total number of classes.
+        dtype: The data type of the output array.
+
+    Returns:
+        A one-hot encoded array of shape (H, W, num_classes).
+    """
+    res = np.zeros(arr.shape + (num_classes,), dtype = dtype)
+
+    for c in range(num_classes):
+        # Set the channel 'c' to 1 where the input array has label 'c'
+        res[..., c] = (arr == c)
+    return res
+
+def standardise(img_path):
+    """
+    Helper: if the file is 2D (H, W), rebuild it as a single-slice 3D volume (H, W, 1).
+
+    Args:
+        img_path: Path to the NIfTI image file.
+    Returns:
+        A Nifti1Image object with standardized dimensions.
+    """
+    nii = nib.load(img_path)
+
+    if len(nii.shape) == 2:
+        # Data is 2D (H, W). We need to make it 3D (H, W, 1).
+        data_2d = nii.get_fdata(caching = "unchanged") # Shape (H, W)
+        data_3d = np.expand_dims(data_2d, axis = -1) # Shape (H, W, 1)
+
+        # Re-create the NIfTI object with the new 3D data
+        new_nii = Nifti1Image(data_3d, nii.affine, nii.header)
+
+        # Manually update the header to reflect the 3D shape
+        new_nii.header.set_data_shape(data_3d.shape)
+        return new_nii
+    elif len(nii.shape) == 3:
+        # It's already 3D, just return it.
+        return nii
+    return nii
+
+# load medical image functions
+def load_data_2D(imageNames, normalise = False, categorical = False, num_classes = None, dtype = np.float32, getAffines = False, first_n = 0):
+    """
+    Load medical image data from names, cases list provided into a list for each.
+    Altered to account for slices being different sizes, by resampling to a template image.
+ + Args: + imageNames: list of paths to NIfTI image files + normalise: bool (normalise the image 0.0-1.0) + categorical: bool (If True, 'num_classes' must also be provided) + num_classes: int (The total number of classes for one-hot encoding, e.g., 6) + getAffines: bool (Return the affine matrices along with the images) + first_n: int (Stop loading after n images for quick loading and testing scripts) + + Returns: + images: np.ndarray of shape (N, H, W) or (N, H, W, C) depending on 'categorical' + affines: list of affine matrices (if getAffines is True) + """ + # Validate mask and classes inputs + if categorical and num_classes is None: + raise ValueError("You should specify the number of classes when loading categorical mask data.") + + affines = [] # Spatial coordinates list + + # Load a template image to get dimensions + try: + template_nifti = standardise(imageNames[0]) + except Exception as e: + print(f"Error loading template image: {imageNames[0]}. {e}") + return + + num = len(imageNames) if first_n == 0 else first_n + + first_case = template_nifti.get_fdata(caching="unchanged") + + if len(first_case.shape) == 3: + first_case = first_case [:,:,0] # sometimes extra dims, remove to keep 2D slice + + if categorical: + # first_case = to_channels(first_case, dtype = dtype) + rows, cols = first_case.shape + channels = num_classes + images = np.zeros((num, rows, cols, channels), dtype = dtype) + else: + rows, cols = first_case.shape + images = np.zeros((num, rows, cols), dtype = dtype) + + if categorical: + interpolation = "nearest" # Preserve integer labels + else: + interpolation = "linear" # Average pixels for smooth image + + for i, inName in enumerate(tqdm(imageNames[:num])): + niftiImage = standardise(inName) # Loads the image + # resampled nifti to match template + resampled_nifti = resample_to_img( + niftiImage, + template_nifti, + interpolation = interpolation, + # Suppressing annoying warnings + force_resample=True, + copy_header=True + ) + # Get data from the *resampled* image + inImage = resampled_nifti.get_fdata(caching = "unchanged") # read disk only + affine = resampled_nifti.affine + if len(inImage.shape) == 3: + inImage = inImage [:,:,0] # sometimes extra dims in HipMRI_study data + inImage = inImage.astype(dtype) + + if normalise and not categorical: + # ~ inImage = inImage / np.linalg.norm(inImage) + # # ~ inImage = 255. * inImage / inImage.max () + inImage = (inImage - inImage.mean()) / inImage.std() + elif(normalise and categorical): + raise ValueError("You probably didn't mean to normalise categorical mask data.") + + if categorical: + inImage = to_channels(inImage, num_classes = num_classes, dtype = dtype) + images[i, :, :, :] = inImage + else: + images [i,:,:] = inImage + affines.append(affine) + + if first_n != 0 and i == first_n: + break + + if getAffines: + return images, affines + else: + return images + +class HipMRI2D(Dataset): + """ + Dataset class for segmentation for HipMRI 2D dataset. + + Args: + dataset: str, one of "train", "validate", "test" to specify which dataset to load. + first_n: int, number of samples to load for quick testing (default: 0, load all). + + Returns: + A PyTorch Dataset object that can be used with DataLoader for training/validation/testing. + """ + def __init__(self, dataset, first_n = 0): + """ + Initialize the HipMRI2D dataset. + + Args: + dataset: str, one of "train", "validate", "test" to specify which dataset to load. + first_n: int, number of samples to load for quick testing (default: 0, load all). 
+ """ + self.dataset = load_data_2D(sorted(Path(path + dataset).glob("*.gz")), normalise = True, categorical = False, first_n = first_n) # Shape (N, H, W) + self.mask_data = load_data_2D(sorted(Path(path + "seg_" + dataset).glob("*.gz")), normalise = False, categorical = True, num_classes = 6, first_n = first_n) # Shape (N, H, W, C) + self.num_classes = self.mask_data.shape[-1] + + print(f"Image array shape: {self.dataset.shape}") # e.g., (100, 256, 128) + print(f"Mask array shape: {self.mask_data.shape}") # e.g., (100, 256, 128, 6) + + def __len__(self): + """ + Returns the total number of samples in the dataset. + + Returns: + int: Number of samples in the dataset. + """ + return len(self.dataset) + + def __getitem__(self, index): + """ + Retrieve the image and corresponding mask at the specified index. + + Args: + index: Index of the sample to retrieve. + + Returns: + A tuple (image, mask) where: + - image is the preprocessed image tensor. + - mask is the binary mask tensor for the hip region. + """ + # Get filename + image_np = self.dataset[index] # Shape (H, W) + mask_np = self.mask_data[index] # Shape (H, W, C) + + image_tensor = torch.from_numpy(image_np).float() + mask_tensor = torch.from_numpy(mask_np).float() + + image_tensor = image_tensor.unsqueeze(0) # (H, W) -> (1, H, W) + + # Permute the mask from "channels-last" to "channels-first" becuase PyTorch expects (C, H, W) + mask_tensor = mask_tensor.permute(2, 0, 1) # (H, W, C) -> (C, H, W) + + return image_tensor, mask_tensor + + def get_mean(self): + return np.mean(self.dataset) + + def get_std(self): + return np.std(self.dataset) + +class LoadData(DataLoader): + """ + Custom DataLoader for the HipMRI2D dataset. + Inherits from torch.utils.data.DataLoader. + """ + def __init__(self, dataset, first_n = SUBSET, batch_size=BATCH_SIZE, shuffle = True): + """ + Initialize the DataLoader. + + Args: + dataset: An instance of the HipMRI2D dataset. + batch_size: Number of samples per batch to load (default: 16). + shuffle: Whether to shuffle the data at every epoch (default: True). + """ + super().__init__(dataset=HipMRI2D(dataset, first_n = first_n), batch_size=batch_size, shuffle=shuffle) + + def __getitem__(self, dataset = "train"): + """ + Load the dataset and return a DataLoader. + + Args: + dataset: str, one of "train", "validate", "test" to specify which dataset to load. + + Returns: + A DataLoader object for the specified dataset. 
+ """ + retrieved_dataset = HipMRI2D(dataset, first_n = SUBSET) + loaded_data = DataLoader(retrieved_dataset, batch_size = BATCH_SIZE, shuffle = True) + + return loaded_data + +if __name__ == "__main__": + print("💛 Loading training data 💛") + LoadData(dataset = "train", first_n = SUBSET, batch_size = BATCH_SIZE, shuffle = True) + print("💚 Training data loading complete 💚") + + print("💛 Loading validation data 💛") + LoadData(dataset = "validate", first_n = SUBSET, batch_size = BATCH_SIZE, shuffle = False) + print("💚 Validation data loading complete 💚") diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_0_slice_0.nii.gz b/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_0_slice_0.nii.gz new file mode 100644 index 000000000..c3692dd75 Binary files /dev/null and b/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_0_slice_0.nii.gz differ diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_0_slice_0.nii.png b/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_0_slice_0.nii.png new file mode 100644 index 000000000..16da3f4cf Binary files /dev/null and b/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_0_slice_0.nii.png differ diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_1_slice_0.nii.gz b/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_1_slice_0.nii.gz new file mode 100644 index 000000000..48a2bdb5b Binary files /dev/null and b/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_1_slice_0.nii.gz differ diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_1_slice_0.nii.png b/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_1_slice_0.nii.png new file mode 100644 index 000000000..15d562bcb Binary files /dev/null and b/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_1_slice_0.nii.png differ diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_2_slice_0.nii.gz b/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_2_slice_0.nii.gz new file mode 100644 index 000000000..50b983355 Binary files /dev/null and b/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_2_slice_0.nii.gz differ diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_2_slice_0.nii.png b/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_2_slice_0.nii.png new file mode 100644 index 000000000..22223eb33 Binary files /dev/null and b/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_case_040_week_2_slice_0.nii.png differ diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/modules.py b/recognition/43711451_HipMRI2D_AttentionUNET/modules.py new file mode 100644 index 000000000..fbdfe12d2 --- /dev/null +++ b/recognition/43711451_HipMRI2D_AttentionUNET/modules.py @@ -0,0 +1,249 @@ +# recognition\43711451_HipMRI2D_AttentionUNET\modules.py +""" +Contains the implementation of the Improved UNet segmentation model to be used for training and prediction +""" + +import torch +import torch.nn as nn + +__author__ = "Cleodora Kizmann" +__copyright__ = "Copyright 2025, Cleodora Kizmann" +__credits__ = ["Cleodora Kizmann"] +__license__ = "Apache License 2.0" 
+__version__ = "1.0.1" +__maintainer__ = "Cleodora Kizmann" +__email__ = "cleodora.kizmann@student.uq.edu.au" +__status__ = "Prototype" + +# Device configuration +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Hyperparameters +BATCH_SIZE = 16 # I got 8GB VRAM on my GPU so I might be pushing this a little +NUM_CLASSES = 6 + +class AttentionGate(nn.Module): + """ + Attention Gate for U-Net skip connections. + + Args: + gating_channels: Number of channels in the gating signal (from decoder). + skip_channels: Number of channels in the skip connection (from encoder). + inter_channels: Number of intermediate channels. + """ + + def __init__(self, gating_channels, skip_channels, inter_channels): + """ + Initialise the Attention Gate. + + Args: + gating_channels: Number of channels in the gating signal (from decoder). + skip_channels: Number of channels in the skip connection (from encoder). + inter_channels: Number of intermediate channels. + """ + super(AttentionGate, self).__init__() + self.gating_signal = nn.Sequential( + nn.Conv2d(gating_channels, inter_channels, kernel_size = 1, stride = 1, padding = 0, bias = True), + nn.BatchNorm2d(inter_channels) + ) + + self.skip_connection = nn.Sequential( + nn.Conv2d(skip_channels, inter_channels, kernel_size = 1, stride = 1, padding = 0, bias = True), + nn.BatchNorm2d(inter_channels) + ) + + self.attention_map = nn.Sequential( + nn.Conv2d(inter_channels, 1, kernel_size = 1, stride = 1, padding = 0, bias = True), + nn.BatchNorm2d(1), + nn.Sigmoid() + ) + + self.relu = nn.ReLU(inplace = True) + + def forward(self, g, x): + """ + Forward pass of the Attention Gate. + + Args: + g: Gating signal (from decoder) + x: Skip connection (from encoder) + """ + g1 = self.gating_signal(g) + x1 = self.skip_connection(x) + attention_map = self.relu(g1 + x1) + attention_map = self.attention_map(attention_map) + return x * attention_map + +class DoubleConv(nn.Module): + """ + (Convolution -> BatchNorm -> ReLU) * 2 + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + """ + + def __init__(self, in_channels, out_channels): + """ + Initialise the Double Convolution block. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + """ + super().__init__() + self.double_conv = nn.Sequential( + # First convolution + nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + + # Second convolution + nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + """ + Forward pass of the Double Convolution block. + + Args: + x: Input tensor. + + Returns: + Output tensor after two convolutions. + """ + return self.double_conv(x) + +class AttentionUNet(nn.Module): + """ + U-Net architecture for image segmentation. + Altered to include Attention Gates in skip connections. + + Args: + num_channels: Number of input image channels (Keeping it at one since the MRIs are going to be greyscale). + num_classes: Number of output classes (we going 6 for this one). + """ + + def __init__(self, num_channels = 1, num_classes = 6): + """ + Initialise the Attention U-Net model. + + Args: + num_channels: Number of input image channels. + num_classes: Number of output classes. 
+ """ + super(AttentionUNet, self).__init__() + self.num_channels = num_channels + self.num_classes = num_classes + + # Encoder (Down Path) + # Each 'inc', 'down1', 'down2', etc., is a "step" in the U. + self.inc = DoubleConv(num_channels, 64) + self.down1 = DoubleConv(64, 128) + self.down2 = DoubleConv(128, 256) + self.down3 = DoubleConv(256, 512) + self.down4 = DoubleConv(512, 1024) # The bottleneck + # Max pooling for down-sampling + self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2) + + # Decoder (Up Path) + # ConvTranspose2d for up-sampling doubles the H/W and halves the channels. + self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size = 2, stride = 2) + self.conv1 = DoubleConv(1024, 512) # 512 (from up) + 512 (from skip) = 1024 + + self.up2 = nn.ConvTranspose2d(512, 256, kernel_size = 2, stride = 2) + self.conv2 = DoubleConv(512, 256) # 256 (from up) + 256 (from skip) = 512 + + self.up3 = nn.ConvTranspose2d(256, 128, kernel_size = 2, stride = 2) + self.conv3 = DoubleConv(256, 128) # 128 (from up) + 128 (from skip) = 256 + + self.up4 = nn.ConvTranspose2d(128, 64, kernel_size = 2, stride = 2) + self.conv4 = DoubleConv(128, 64) # 64 (from up) + 64 (from skip) = 128 + + # Output Layer + # Final 1x1 convolution to map to the number of classes + self.outc = nn.Conv2d(64, num_classes, kernel_size=1) + + # ATTENTION GATES + # gating_channels = channels from decoder (up-sampled) + # skip_channels = channels from encoder (skip connection) + # inter_channels = intermediate channels (can be half of skip_channels) + self.Att1 = AttentionGate(gating_channels = 512, skip_channels = 512, inter_channels = 256) + self.Att2 = AttentionGate(gating_channels = 256, skip_channels = 256, inter_channels = 128) + self.Att3 = AttentionGate(gating_channels = 128, skip_channels = 128, inter_channels = 64) + self.Att4 = AttentionGate(gating_channels = 64, skip_channels = 64, inter_channels = 32) + + def forward(self, x): + """ + Forward pass of the Attention U-Net model. + + Args: + x: Input image tensor. + + Returns: + logits: Output segmentation logits. 
+ """ + # x is the input image, e.g., (BatchSize, 3, 256, 256) + + # Encoder + # We save the output of each encoder block to use in the skip connections + x1 = self.inc(x) # -> (B, 64, 256, 256) + x2 = self.pool(x1) # -> (B, 64, 128, 128) + x2 = self.down1(x2) # -> (B, 128, 128, 128) + + x3 = self.pool(x2) # -> (B, 128, 64, 64) + x3 = self.down2(x3) # -> (B, 256, 64, 64) + + x4 = self.pool(x3) # -> (B, 256, 32, 32) + x4 = self.down3(x4) # -> (B, 512, 32, 32) + + x5 = self.pool(x4) # -> (B, 512, 16, 16) + x5 = self.down4(x5) # -> (B, 1024, 16, 16) - This is the bottleneck + + # Decoder + up_x = self.up1(x5) # (B, 512, H/8, W/8) + + # ATTENTION GATE + # 'g' is the gating signal from decoder, 'x' is the skip connection + x4_att = self.Att1(g = up_x, x = x4) + + skip_x = torch.cat([x4_att, up_x], dim = 1) # Concatenate re-weighted skip + x = self.conv1(skip_x) + + up_x = self.up2(x) # (B, 256, H/4, W/4) + + # ATTENTION GATE + x3_att = self.Att2(g = up_x, x = x3) + + skip_x = torch.cat([x3_att, up_x], dim = 1) + x = self.conv2(skip_x) + + # Step 3 + up_x = self.up3(x) # (B, 128, H/2, W/2) + + # ATTENTION GATE + x2_att = self.Att3(g = up_x, x = x2) + + skip_x = torch.cat([x2_att, up_x], dim = 1) + x = self.conv3(skip_x) + + # Step 4 + up_x = self.up4(x) # (B, 64, H, W) + + # ATTENTION GATE + x1_att = self.Att4(g = up_x, x = x1) + + skip_x = torch.cat([x1_att, up_x], dim = 1) + x = self.conv4(skip_x) + + # Output Layer + logits = self.outc(x) + return logits + +if __name__ == "__main__": + print(f"💛 Initialising the Attention U-Net model on {device} 💛") + model = AttentionUNet(num_channels = 1 , num_classes = NUM_CLASSES) + model.to(device) + print(f"💚 Model initialisation complete on {device} 💚") diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/predict.py b/recognition/43711451_HipMRI2D_AttentionUNET/predict.py new file mode 100644 index 000000000..101d12e01 --- /dev/null +++ b/recognition/43711451_HipMRI2D_AttentionUNET/predict.py @@ -0,0 +1,214 @@ +# recognition/43711451_HipMRI2D_AttentionUNET/predict.py +""" +Contains the main prediction script for the model after training +""" + +import torch +import numpy as np +from dataset import standardise, resample_to_img +from modules import AttentionUNet as model +import nibabel as nib +from nibabel import Nifti1Image +from matplotlib import pyplot as plt +from pathlib import Path + +__author__ = "Cleodora Kizmann" +__copyright__ = "Copyright 2025, Cleodora Kizmann" +__credits__ = ["Cleodora Kizmann"] +__license__ = "Apache License 2.0" +__version__ = "1.0.1" +__maintainer__ = "Cleodora Kizmann" +__email__ = "cleodora.kizmann@student.uq.edu.au" +__status__ = "Prototype" + +# Device configuration +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Hyperparameters +BATCH_SIZE = 16 # I got 8GB VRAM on my GPU so I might be pushing this a little +SUBSET = 0 +NUM_CLASSES = 6 +NUM_EPOCHS = 25 +LEARNING_RATE = 1e-4 +SAVED_MODEL_PATH = "C:/Users/cleod/OneDrive/Documents/Work/UQ/COMP3710/Python Workspace/3710-pattern-analysis/recognition/43711451_HipMRI2D_AttentionUNET/saved_models/full_set_5_epochs.pth" +TEMPLATE_IMG_PATH = "D:/keras_slices_data/keras_slices_train/case_004_week_0_slice_0.nii.gz" +SELECTED_SLICE = "case_040_week_3_slice_0.nii.gz" +INPUT_IMG_PATH = f"D:/keras_slices_data/keras_slices_test/{SELECTED_SLICE}" +OUTPUT_MASK_PATH = f"C:/Users/cleod/OneDrive/Documents/Work/UQ/COMP3710/Python Workspace/3710-pattern-analysis/recognition/43711451_HipMRI2D_AttentionUNET/mask_output/pred_mask_{SELECTED_SLICE}" + +def 
predict(model, image_path, template): + """ + Runs inference on a single NIfTI image, replicating the training pre-processing. + + Args: + model: The trained Attention U-Net model. + image_path: Path to the input NIfTI image. + template: Nifti1Image template for resampling. + Returns: + pred_mask_np: NumPy array of the predicted mask. + """ + + print(f"Processing: {image_path}") + + # Load & Standardise + nifti_image = standardise(image_path) + + # Resample + resampled_nifti = resample_to_img(nifti_image, + template, + interpolation="linear", + # Suppressing annoying warnings + force_resample=True, + copy_header=True) + + # Get Data & Normalize + image_np = resampled_nifti.get_fdata(caching="unchanged") + + # Handle extra dims (from load_data_2D) + if len(image_np.shape) == 3: + image_np = image_np[:,:,0] # Shape (H, W) + + image_np = image_np.astype(np.float32) + + # Normalize *per image*, just like load_data_2D does + image_np = (image_np - image_np.mean()) / image_np.std() + + # Convert to Tensor + image_tensor = torch.from_numpy(image_np).float() + + # Add channel dim: (H, W) -> (1, H, W) + image_tensor = image_tensor.unsqueeze(0) + + # Add batch dim: (1, H, W) -> (1, 1, H, W) + image_tensor = image_tensor.unsqueeze(0) + + # Move to device + image_tensor = image_tensor.to(device) + + # Run Model + model.eval() # Set model to evaluation mode + with torch.no_grad(): + # Get raw logits [1, 6, H, W] + logits = model(image_tensor) + + # Post-process + # Apply softmax to get probabilities + probs = torch.softmax(logits, dim = 1) # dim=1 is the Class dimension + + # Get the most likely class index for each pixel + pred_mask = torch.argmax(probs, dim = 1) # Shape [1, H, W] + + # Remove batch dim: [1, H, W] -> [H, W] + pred_mask = pred_mask.squeeze(0) + + # Move to CPU as a NumPy array for saving + pred_mask_np = pred_mask.cpu().numpy() + + # Return the mask AND the spatial info for saving + return pred_mask_np, resampled_nifti.affine, resampled_nifti.header + +def visualize_prediction(image_path, template, pred_mask): + """ + Visualises the original image, the predicted mask, and an overlay. + + Args: + image_path: Path to the original image. + template: Nifti1Image template for resampling. + pred_mask: NumPy array of the predicted mask. 
+ """ + + # Load & Resample original image + nii = standardise(image_path) + resampled_nii = resample_to_img(nii, + template, + interpolation="linear", + # Suppressing annoying warnings + force_resample=True, + copy_header=True) + image_np = resampled_nii.get_fdata(caching="unchanged") + + # Handle 3D (H, W, 1) -> 2D (H, W) + if len(image_np.shape) == 3: + image_np = image_np[:,:,0] + + # Load mask + mask_nii = nib.load(pred_mask) + mask_np = mask_nii.get_fdata() + + plt.figure(figsize=(18, 6)) + plt.suptitle(f"{INPUT_IMG_PATH.partition('D:/keras_slices_data/keras_slices_test/')[2]}", fontsize=16) + # Original Image + plt.subplot(1, 3, 1) + plt.title("MRI Slice") + plt.imshow(image_np, cmap='gray') # Grayscale for the photo + plt.axis('off') + + # Predicted Mask + plt.subplot(1, 3, 2) + plt.title("Predicted Mask") + plt.imshow(mask_np, cmap='viridis') + plt.axis('off') + + # Overlay + plt.subplot(1, 3, 3) + plt.title("Mask Overlay on MRI") + plt.imshow(image_np, cmap='gray') + + mask_np_masked = np.ma.masked_where(mask_np == 0, mask_np) + plt.imshow(mask_np_masked, cmap='viridis', alpha=0.6) + plt.axis('off') + + plt.tight_layout() + # Save the figure + output_fig_path = Path(OUTPUT_MASK_PATH).with_suffix('.png') + plt.savefig(output_fig_path, dpi=300) + print(f"Visualization saved to {output_fig_path}") + plt.show() + +if __name__ == "__main__": + print(f"💛 Initialising the Attention U-Net model from path:{SAVED_MODEL_PATH} on {device} 💛") + model = model(num_channels = 1, num_classes = NUM_CLASSES) + print(f"Loading saved weights from {SAVED_MODEL_PATH}...") + # map_location = device just makes sure it works even if trained on a GPU and are now predicting on a CPU. + model.load_state_dict(torch.load(SAVED_MODEL_PATH, map_location = device)) + model.to(device) + model.eval() + print(f"💚 Model loaded on {device} 💚") + + try: + print(f"Loading template from {TEMPLATE_IMG_PATH}...") + template = standardise(TEMPLATE_IMG_PATH) + except FileNotFoundError: + print(f"Error: Template image not found at {TEMPLATE_IMG_PATH}") + exit() + + try: + pred_mask, affine, header = predict( + model = model, + image_path = INPUT_IMG_PATH, + template = template, + ) + + print(f"Saving mask preditction to {OUTPUT_MASK_PATH}...") + + # Create a new NIfTI object for the mask + # We use the affine and header from the *resampled input* + # so the mask perfectly overlays it. + mask_nii = Nifti1Image(pred_mask.astype(np.int16), affine, header) + + # Update header to reflect 2D shape (or 3D with 1 slice) + mask_nii.header.set_data_shape(pred_mask.shape) + mask_nii.header.set_data_dtype(np.int16) # Save as integer + + nib.save(mask_nii, OUTPUT_MASK_PATH) + + print(f"Prediction complete. 
Mask saved to {OUTPUT_MASK_PATH}") + + visualize_prediction( + image_path = INPUT_IMG_PATH, + template = template, + pred_mask = OUTPUT_MASK_PATH, + ) + + except FileNotFoundError: + print(f"Error: Input image not found at {INPUT_IMG_PATH}") diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/train.py b/recognition/43711451_HipMRI2D_AttentionUNET/train.py new file mode 100644 index 000000000..86a2572f7 --- /dev/null +++ b/recognition/43711451_HipMRI2D_AttentionUNET/train.py @@ -0,0 +1,134 @@ +# recognition\43711451_HipMRI2D_AttentionUNET\train.py +""" +Contains the main training script for the model +""" + +from utils import Dice +from dataset import LoadData +from modules import AttentionUNet as model +import numpy as np +import random +import torch + +__author__ = "Cleodora Kizmann" +__copyright__ = "Copyright 2025, Cleodora Kizmann" +__credits__ = ["Cleodora Kizmann"] +__license__ = "Apache License 2.0" +__version__ = "1.0.1" +__maintainer__ = "Cleodora Kizmann" +__email__ = "cleodora.kizmann@student.uq.edu.au" +__status__ = "Prototype" + +# Device configuration +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Set random seeds for reproducibility +torch.manual_seed(42) +np.random.seed(42) +random.seed(42) +if torch.cuda.is_available(): + torch.cuda.manual_seed(42) + +# Hyperparameters +BATCH_SIZE = 16 # I got 8GB VRAM on my GPU so I might be pushing this a little +SUBSET = 0 +NUM_CLASSES = 6 +NUM_EPOCHS = 10 +LEARNING_RATE = 1e-4 +SAVED_MODEL_PATH = "C:/Users/cleod/OneDrive/Documents/Work/UQ/COMP3710/Python Workspace/3710-pattern-analysis/recognition/43711451_HipMRI2D_AttentionUNET/saved_models/full_set_x_epochs.pth" + +def train(training_loader = None, + validation_loader = None, + epochs = NUM_EPOCHS, + model = model, +): + """ + Train the Attention U-Net model with the training and validation data loaders. + + Args: + training_loader: DataLoader for training data. + validation_loader: DataLoader for validation data. + epochs: Number of training epochs. + model: The Attention U-Net model to be trained. + Returns: + losses: List of average training losses per epoch. 
+ """ + + criterion = Dice() + optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE) + + losses = [] + + print("🤜 Starting training 🤛") + for epoch in range(epochs): + model.train() + epoch_loss = 0 + + # Training loop with progress + for images, masks in training_loader: + images = images.to(device) + masks = masks.to(device) + + optimizer.zero_grad() + outputs = model(images) + + # print(f" [DEBUG 1] image shape: {outputs.shape}, mask shape: {masks.shape}") + loss = criterion(outputs, masks) + # print("[DEBUG 2] Passed loss calculation") + + # Backward pass + loss.backward() + optimizer.step() + + epoch_loss += loss.item() + + avg_loss = epoch_loss / len(training_loader) + losses.append(avg_loss) + + print(f"▶ Epoch {epoch+1}/{epochs} Complete: Avg Loss = {avg_loss:.4f} ▶️") + + # Start of Validation Loop + model.eval() # Set model to evaluation mode + epoch_loss_eval = 0 + + with torch.no_grad(): + for images, masks in validation_loader: + images = images.to(device) + masks = masks.to(device) + + # Forward pass only + outputs = model(images) + + loss = criterion(outputs, masks) + epoch_loss_eval += loss.item() + + avg_loss_eval = epoch_loss_eval / len(validation_loader) + + print(f"📈 Epoch {epoch+1} / {epochs}") + print(f"📈 Training Loss: {avg_loss:.4f}") + print(f"📈 Validation Loss: {avg_loss_eval:.4f}") + + if 1 - avg_loss_eval >= 0.75: # Saves Model if Validation Dice Coefficient is at least 0.75 + torch.save(model.state_dict(), SAVED_MODEL_PATH) + print(f"💲 Validation Dice Coefficient below 0.75. Model saved to {SAVED_MODEL_PATH} 💲") + else: + print("⛔ Model not saved: Validation Dice Coefficient below 0.75 ⛔") + + print("🛑 Attention!! Training complete with Attention U-Net!!! 🛑") + return losses + +if __name__ == "__main__": + print("💛 Loading training data 💛") + training_loader = LoadData(dataset = "train", first_n = SUBSET, batch_size = BATCH_SIZE, shuffle = True) + print("💚 Training data loading complete 💚") + + print("💛 Loading validation data 💛") + validation_loader = LoadData(dataset = "validate", first_n = SUBSET, batch_size = BATCH_SIZE, shuffle = False) + print("💚 Validation data loading complete 💚") + + print(f"💛 Initialising the Attention U-Net model on {device} 💛") + model = model(num_channels = 1, num_classes = NUM_CLASSES) + model.to(device) + print(f"💚 Model initialisation complete on {device} 💚") + + train(training_loader = training_loader, validation_loader = validation_loader, model = model) diff --git a/recognition/43711451_HipMRI2D_AttentionUNET/utils.py b/recognition/43711451_HipMRI2D_AttentionUNET/utils.py new file mode 100644 index 000000000..a5eeb592f --- /dev/null +++ b/recognition/43711451_HipMRI2D_AttentionUNET/utils.py @@ -0,0 +1,72 @@ +# recognition\43711451_HipMRI2D_AttentionUNET\utils.py +""" +Contains utility functions for the HipMRI 2D Slice Dataset project +""" + +import torch +import torch.nn as nn + +__author__ = "Cleodora Kizmann" +__copyright__ = "Copyright 2025, Cleodora Kizmann" +__credits__ = ["Cleodora Kizmann"] +__license__ = "Apache License 2.0" +__version__ = "1.0.1" +__maintainer__ = "Cleodora Kizmann" +__email__ = "cleodora.kizmann@student.uq.edu.au" +__status__ = "Prototype" + +class Dice(nn.Module): + """ + Dice Coefficient = (2 * |X ∩ Y|) / (|X| + |Y|) + Dice Loss = 1 - Dice Coefficient + + Args: + num_classes: int, number of classes for segmentation. + apply_softmax: bool, whether to apply softmax to predictions. 
+ smooth (float): Smoothing factor to avoid division by zero (default: 1e-6) + """ + def __init__(self, num_classes = 6, apply_softmax=True, smooth=1e-6): + super(Dice, self).__init__() + self.num_classes = num_classes + self.apply_softmax = apply_softmax + self.smooth = smooth + + def forward(self, predictions, targets): + """ + Dice Loss calculation. + + Args: + predictions: Raw logits from model [B, C, H, W] + targets: One-hot encoded ground truth [B, C, H, W] + + Returns: + Dice Loss value. + """ + # Apply softmax if predictions are logits + if self.apply_softmax: + # Apply softmax across the channel dimension (dim=1) + predictions = torch.softmax(predictions, dim=1) + + dice_per_class = 0.0 + + for i in range(self.num_classes): + pred_class = predictions[:, i, :, :] + target_class = targets[:, i, :, :] + + # Flatten tensors using reshape to handle non-contiguous memory layout + pred_flat = pred_class.reshape(-1) + target_flat = target_class.reshape(-1) + + # Calculate intersection and union + intersection = (pred_flat * target_flat).sum() + dice_sum = pred_flat.sum() + target_flat.sum() + + dice_coeff = (2. * intersection + self.smooth) / (dice_sum + self.smooth) + + dice_per_class += dice_coeff + + avg_dice = dice_per_class / self.num_classes + + # Return Dice Loss (1 - Dice Coefficient) + return 1 - avg_dice + \ No newline at end of file