Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,21 @@ dependencies = [
"monai",
"click",
"copick",
"kornia",
"nibabel",
"mrcfile",
"starfile",
"lightning",
"matplotlib",
"kornia",
"opencv-python",
"multiprocess",
"torchmetrics",
"scikit-learn",
"ipywidgets",
"umap-learn",
"torch-ema",
"tensorboard",
"multiprocess",
"torchmetrics",
"scikit-learn",
"copick-utils",
"opencv-python",
]

[project.optional-dependencies]
Expand Down
36 changes: 27 additions & 9 deletions saber/classifier/datasets/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
Compose, EnsureChannelFirstd, NormalizeIntensityd, Orientationd,
RandRotate90d, RandFlipd, RandScaleIntensityd, RandShiftIntensityd,
RandAdjustContrastd, RandGaussianNoised, RandAffined, RandomOrder,
RandGaussianSmoothd,
RandGaussianSmoothd, SqueezeDimd, ToNumpyd, EnsureTyped, ResizeD
)
from saber.classifier.datasets.RandMaskCrop import AdaptiveCropd
from torch.utils.data import random_split
import torch

def get_preprocessing_transforms(random_translations=False):
transforms = Compose([
Expand All @@ -20,13 +21,6 @@ def get_preprocessing_transforms(random_translations=False):

def get_training_transforms():
train_transforms = Compose([
# RandAffined(
# keys=["image", "mask"],
# prob=0.75,
# translate_range=(30, 30),
# padding_mode="border",
# mode=("bilinear", "nearest")
# ),
RandomOrder([
RandRotate90d(keys=["image", "mask"], prob=0.5, spatial_axes=[0, 1]),
RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=0),
Expand All @@ -45,4 +39,28 @@ def get_validation_transforms():
def split_dataset(dataset, val_split=0.2, seed=None):
    """Randomly split a dataset into train and validation subsets.

    Args:
        dataset: Any dataset implementing ``__len__``/``__getitem__``.
        val_split: Fraction of samples assigned to validation (default 0.2).
        seed: Optional integer; when given, the split is reproducible.

    Returns:
        (train_subset, val_subset) from ``torch.utils.data.random_split``.
    """
    train_size = int(len(dataset) * (1 - val_split))
    # Assign the remainder to validation so the two sizes always sum to len(dataset).
    val_size = len(dataset) - train_size
    if seed is None:
        return random_split(dataset, [train_size, val_size])
    # Dedicated generator keeps the split deterministic without touching global RNG state.
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [train_size, val_size], generator=generator)

def get_finetune_transforms(target_size=(1024,1024)):
    """Build the fine-tuning transform pipeline.

    Resizes image/mask pairs to ``target_size``, applies a randomly ordered
    set of geometric and intensity augmentations, then strips the channel
    dimension and converts back to numpy arrays.

    Args:
        target_size: Target (H, W) for the resize step.

    Returns:
        A MONAI ``Compose`` ready for use in a fine-tuning dataset.
    """
    pair_keys = ["image", "mask"]

    # Augmentations applied in a random order each call; geometric ops act on
    # both image and mask, intensity ops on the image only.
    augmentations = RandomOrder([
        RandRotate90d(keys=pair_keys, prob=0.5, spatial_axes=[0, 1]),
        RandFlipd(keys=pair_keys, prob=0.5, spatial_axis=0),
        RandFlipd(keys=pair_keys, prob=0.5, spatial_axis=1),
        RandScaleIntensityd(keys="image", prob=0.5, factors=(0.85, 1.15)),
        RandShiftIntensityd(keys="image", prob=0.5, offsets=(-0.15, 0.15)),
        RandAdjustContrastd(keys="image", prob=0.5, gamma=(0.85, 1.15)),
        RandGaussianNoised(keys="image", prob=0.5, mean=0.0, std=1.5),
        RandGaussianSmoothd(keys="image", prob=0.5, sigma_x=(0.25, 1.5), sigma_y=(0.25, 1.5)),
    ])

    return Compose([
        # Give both arrays an explicit leading channel axis.
        EnsureChannelFirstd(keys=pair_keys, channel_dim="no_channel"),
        # Float image for interpolation, integer mask for label semantics.
        EnsureTyped(keys=pair_keys, dtype=[torch.float32, torch.int64]),
        # Bilinear for the image, nearest for the mask to keep labels discrete.
        ResizeD(keys=pair_keys, spatial_size=target_size, mode=("bilinear", "nearest")),
        augmentations,
        # Drop the channel axis added above and hand back numpy arrays.
        SqueezeDimd(keys=pair_keys, dim=0),
        ToNumpyd(keys=pair_keys),
    ])
63 changes: 46 additions & 17 deletions saber/classifier/preprocess/split_merge_data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,37 @@
from sklearn.model_selection import train_test_split
from typing import List, Tuple, Dict, Optional
from zarr.convenience import copy as zarr_copy
from pathlib import Path
import click, zarr, os
import numpy as np

def copy_like(src_arr, dst_group, path: str):
    """Copy a zarr array into ``dst_group`` at ``path``, preserving metadata.

    Creates intermediate groups for nested paths (e.g. ``"labels/0"``),
    recreates the destination array with the same shape/dtype/chunking/
    compression as ``src_arr``, copies the data in bulk, and carries over
    the array-level attributes.

    Args:
        src_arr: Source zarr array to replicate.
        dst_group: Destination zarr group to copy into.
        path: Slash-separated path under ``dst_group`` for the new array.
    """
    # Ensure parent groups exist for nested paths like "labels/0".
    parent = dst_group
    *group_names, leaf = path.split('/')
    for name in group_names:
        parent = parent.require_group(name)

    # Only forward dimension_separator when the source actually defines a
    # non-None value: the original hasattr/getattr combo could pass an
    # explicit None, which differs from omitting the kwarg and is rejected
    # by some zarr versions.
    extra_kwargs = {}
    separator = getattr(src_arr, "dimension_separator", None)
    if separator is not None:
        extra_kwargs["dimension_separator"] = separator

    # Create (or reuse) the destination array with identical metadata.
    dst_arr = parent.require_dataset(
        leaf,
        shape=src_arr.shape,
        dtype=src_arr.dtype,
        chunks=src_arr.chunks,
        compressor=src_arr.compressor,
        filters=src_arr.filters,
        order=src_arr.order,
        fill_value=src_arr.fill_value,
        **extra_kwargs,
    )
    dst_arr[:] = src_arr[:]  # fast bulk data copy
    # Preserve array attrs.
    if src_arr.attrs:
        dst_arr.attrs.update(src_arr.attrs)

def split(
input: str,
ratio: float,
Expand Down Expand Up @@ -57,19 +85,25 @@ def split(
items = ['0', 'labels/0', 'labels/rejected']
print('Copying data to train zarr file...')
for key in train_keys:
train_zarr.create_group(key) # Explicitly create the group first
copy_attributes(zfile[key], train_zarr[key])
dst_grp = train_zarr.require_group(key)
copy_attributes(zfile[key], dst_grp)
for item in items:
train_zarr[key][item] = zfile[key][item][:] # [:] ensures a full copy
copy_attributes(zfile[key]['labels'], train_zarr[key]['labels'])
try:
copy_like(zfile[key][item], dst_grp, item)
copy_attributes(zfile[key][item], dst_grp[item])
except Exception as e:
pass

print('Copying data to validation zarr file...')
for key in val_keys:
val_zarr.create_group(key) # Explicitly create the group first
copy_attributes(zfile[key], val_zarr[key])
dst_grp = val_zarr.require_group(key)
copy_attributes(zfile[key], dst_grp)
for item in items:
val_zarr[key][item] = zfile[key][item][:] # [:] ensures a full copy
copy_attributes(zfile[key]['labels'], val_zarr[key]['labels'])
try:
copy_like(zfile[key][item], dst_grp, item)
copy_attributes(zfile[key][item], dst_grp[item])
except Exception as e:
pass

# Print summary
print(f"\nSplit Summary:")
Expand Down Expand Up @@ -124,18 +158,16 @@ def merge(inputs: List[str], output: str):
write_key = session_label + '_' + key

# Create the group and copy its attributes
new_group = mergedZarr.create_group(write_key) # Explicitly create the group first
copy_attributes(zfile[key], new_group)
dst_grp = mergedZarr.require_group(write_key)
copy_attributes(zfile[key], dst_grp)

# Copy the data arrays
for item in items:
try:
# [:] ensures a full copy
mergedZarr[write_key][item] = zfile[key][item][:]
copy_like(zfile[key][item], dst_grp, item)
copy_attributes(zfile[key][item], dst_grp[item])
except Exception as e:
pass
# Copy attributes for labels subgroup
copy_attributes(zfile[key]['labels'], new_group['labels'])

# Copy all attributes from the last input zarr file
for attr_name, attr_value in zfile.attrs.items():
Expand Down Expand Up @@ -194,8 +226,5 @@ def copy_attributes(source, destination):
"""
if hasattr(source, 'attrs') and source.attrs:
destination.attrs.update(source.attrs)

# Script entry point: dispatch to the click CLI group defined elsewhere in this module.
if __name__ == '__main__':
    cli()


Empty file added saber/finetune/README.md
Empty file.
Empty file added saber/finetune/__init__.py
Empty file.
Loading
Loading