Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
c64fbff
Updates to gitignore
hllelli2 Jun 19, 2024
2f4d58f
basic setup config for pip install
hllelli2 Jun 19, 2024
6afd99e
Added MapDataset and MapDataLoader
hllelli2 Jun 19, 2024
ef73dd2
Added custom augments and transfroms as well as tiling
hllelli2 Jun 19, 2024
2ab4260
Added tests for new code
hllelli2 Jun 19, 2024
8b9d06d
feat: Add HDF5DataStore class for handling hdf5 files
hllelli2 Jul 2, 2024
d8729a1
Add pytest configuration file and conftest.py for test fixtures
hllelli2 Jul 4, 2024
1a6ed24
Added a dataset config class to have all the defaults in one place
hllelli2 Jul 9, 2024
c8486b8
Create a class for lazy-loading a hdf5 file
hllelli2 Jul 9, 2024
3779ee4
Moved multi-processing to utils, Added an array dataset which can loa…
hllelli2 Jul 9, 2024
b9a91cb
Moved Multi-processing code here, code here to duplciate arrays from …
hllelli2 Jul 9, 2024
ab2ce2f
Changed augment to take arrays instead of map-objects to make it comp…
hllelli2 Jul 9, 2024
d6eda3a
Removed mapObject transform base
hllelli2 Jul 9, 2024
a86c910
chore: Refactor transforms module and update transform classes
hllelli2 Jul 9, 2024
41d40b3
first attempt at reducing if label/weight... is not none repeated code
hllelli2 Jul 9, 2024
6f73aae
chore: Add test fixtures and update conftest.py for test setup
hllelli2 Jul 9, 2024
07e2318
set_gpu is not in the current ccpem-utils so check added JIC
hllelli2 Jul 10, 2024
a6cc9b1
chore: Refactor HDF5DataStore class and add support for temporary dir…
hllelli2 Jul 17, 2024
05f792b
Removed "_" id logic in MapDataLoader, might reimplement later but no…
hllelli2 Jul 17, 2024
5a770f7
Refactor process_datasets function and remove unnecessary code
hllelli2 Jul 17, 2024
0a99fd7
Refactor transforms module and update transform classes
hllelli2 Jul 17, 2024
8772fa6
re-added __del__ method to HDF5Store
hllelli2 Jul 17, 2024
93c8424
Refactor MapDataset and ArrayDataset classes to handle weight tensors
hllelli2 Sep 4, 2024
9128e9a
Refactor
hllelli2 Sep 4, 2024
b2a49c1
Refactor DecomposeToSlices, MapObjectMaskCrop, and MapObjectPadding c…
hllelli2 Sep 4, 2024
ce47030
Refactor test_map_io.py to handle weight tensors and update dataset l…
hllelli2 Sep 4, 2024
5d34da0
Refactor HDF5DataStore class to handle weight tensors and implement c…
hllelli2 Sep 30, 2024
d049499
pre-commit fixes
hllelli2 Sep 30, 2024
1e65dd8
Refactor MapDataLoader to add caching mechanism for HDF5DataStore
hllelli2 Sep 30, 2024
f31870b
Refactor HDF5DataStore class to handle weight tensors and update cach…
hllelli2 Sep 30, 2024
cb501a0
Fixed tests
hllelli2 Sep 30, 2024
6cd4f6a
Refactor MapDataset and ArrayDataset to handle tile generation and sl…
hllelli2 Oct 1, 2024
34770e8
Added LRUCache test
hllelli2 Oct 1, 2024
ddbd2d2
Refactor MapDataset and ArrayDataset to improve code organization and…
hllelli2 Oct 1, 2024
6ff4461
Merge branch 'main' into CCPEM-AddingMapDataset
hllelli2 Oct 1, 2024
df5f192
Merge branch 'CCPEM-AddingMapDataset' into OptimiseHdf5
hllelli2 Oct 1, 2024
f10117d
Refactor unused variables in Rotation90Augment class
hllelli2 Oct 1, 2024
9331356
Add ccpem-utils to project dependencies
hllelli2 Oct 1, 2024
be1a80b
Add h5py and psutil to project dependencies
hllelli2 Oct 1, 2024
6b7f0c7
Fix deprecation warnings in pytest.ini
hllelli2 Oct 1, 2024
3b4196f
Fix deprecation warnings in pytest.ini
hllelli2 Oct 1, 2024
28dc529
Fix deprecation warnings in pytest.ini
hllelli2 Oct 1, 2024
2358a76
Fix deprecation warnings in pytest.ini and update pytest configuration
hllelli2 Oct 1, 2024
a9af9c3
Fix deprecation warnings
hllelli2 Oct 1, 2024
7cd0e06
Update pytest configuration and package installation
hllelli2 Oct 1, 2024
cb6a8c9
Update pytest configuration and package installation
hllelli2 Oct 1, 2024
b95eaa6
Update pytest configuration and package installation
hllelli2 Oct 1, 2024
a8b367c
Fix: __getitem__ logic
hllelli2 Oct 1, 2024
0696806
Add iterator method to HDF5DataStore and simplify key checks in tests
hllelli2 Oct 9, 2024
7fbefdf
Refactor test fixtures to use default scope and improve readability
hllelli2 Oct 9, 2024
36a5883
Refactor: Rename tiles to slice_indicies and update related logic
hllelli2 Oct 9, 2024
b24d299
Refactor: Simplify iterator method in HDF5DataStore
hllelli2 Oct 9, 2024
7d12ece
Refactor: update default voxel parameter type and improve kwargs hand…
hllelli2 Oct 9, 2024
f4555ae
Merge branch 'OptimiseHdf5' of https://github.com/alan-turing-institu…
hllelli2 Oct 9, 2024
2f32536
Updated MapDataLoader to accepted the Mapdataset Kwargs in the load f…
hllelli2 Oct 23, 2024
1844ec7
Fix: Change default value of 'vox' to a float and update related logi…
hllelli2 Oct 23, 2024
96772ae
Merge branch 'OptimiseHdf5' of https://github.com/alan-turing-institu…
hllelli2 Oct 23, 2024
b8d9c65
Cleanup: remove debug print statements from MapDataset class
hllelli2 Oct 24, 2024
d975d44
update pyproject.toml with dependency version
aj26git Dec 10, 2024
dc23c5c
Enhance MapDataLoader: Add background filtering to skip slices with e…
hllelli2 Mar 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
allow-prereleases: true

- name: Install package
run: python -m pip install .[test]
run: python -m pip install -e .[test]

- name: Test package
run: >-
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,7 @@ Thumbs.db
# Common editor files
*~
*.swp


# IDE specific files
.vscode/
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ classifiers = [
"Typing :: Typed",
]

dependencies = ["torch", "numpy", "pandas", "mrcfile", "torchvision", "scipy", "pyarrow"]
dependencies = ["torch", "numpy", "pandas", "mrcfile", "torchvision", "scipy ~= 1.9.3", "pyarrow", "ccpem-utils", "h5py", "psutil", "pillow ~= 9.3"]

[project.optional-dependencies]
test = [
Expand Down Expand Up @@ -60,7 +60,7 @@ minversion = "6.0"
addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
xfail_strict = true
filterwarnings = [
"error",
"ignore::DeprecationWarning",
]
log_cli_level = "INFO"
testpaths = [
Expand Down
53 changes: 53 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Setup configuration for the package
[metadata]
name = caked


# Options for the package

[options]

packages = find:
python_requires = >=3.8
package_dir =
= src


# where to add pip dependencies

install_requires =
torch
numpy
pandas
mrcfile
torchvision
scipy
pyarrow
ccpem-utils
h5py
psutil


[options.packages.find]
where =
src
src/Transforms
src/Wrappers

exclude =
tests
.github
.gitignore
.gitattributes
.pytest_cache
.git
.vscode
.history
*.egg
*.egg-info
docs
site
mkdocs.yml
*.ipynb
.mypy_cache
.ruff_cache
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

from setuptools import setup

setup()
124 changes: 124 additions & 0 deletions src/caked/Transforms/augments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from __future__ import annotations

import random
from enum import Enum

import numpy as np
from ccpem_utils.map.array_utils import rotate_array
from ccpem_utils.map.parse_mrcmapobj import MapObjHandle

from .base import AugmentBase


class Augments(Enum):
""" """

RANDOMROT = "randrot"
ROT90 = "rot90"


def get_augment(augment: str, random_seed) -> AugmentBase:
""" """

if augment == Augments.RANDOMROT.value:
return RandomRotationAugment(random_seed=random_seed)
if augment == Augments.ROT90.value:
return Rotation90Augment(random_seed=random_seed)

msg = f"Unknown Augmentation: {augment}, please choose from {Augments.__members__}"
raise ValueError(msg)


class ComposeAugment:
"""
Compose multiple Augments together.

:param augments: (list) list of augments to compose

:return: (np.ndarrry) transformed array
"""

def __init__(self, augments: list[str], random_seed: int = 42):
self.random_seed = random_seed
self.augments = augments

def __call__(self, data: np.ndarray, **kwargs) -> MapObjHandle:
for augment in self.augments:
data, augment_kwargs = get_augment(augment, random_seed=self.random_seed)(
data, **kwargs
)

kwargs.update(augment_kwargs)

return data, kwargs


class RandomRotationAugment(AugmentBase):
"""
Random or controlled rotation (if ax and an kwargs provided).

:param data: (np.ndarray) 3d volume
:param return_all: (bool) if True, will parameters of the rotation (ax, an)
:param interp: (bool) if True, will interpolate the rotation
:param ax: (int) 0 for yaw, 1 for pitch, 2 for roll
:param an: (int) number of times to rotate, between <1 and 3>

:return: (np.ndarray) rotated volume or (np.ndarray, int, int) rotated volume and rotation parameters
"""

def __init__(self, random_seed: int = 42):
super().__init__(random_seed)

def __call__(
self,
data: np.ndarray,
**kwargs,
) -> np.ndarray | tuple[np.ndarray, int, int]:
ax = kwargs.get("ax", None)
an = kwargs.get("an", None)
interp = kwargs.get("interp", True)

if (ax is not None and an is None) or (ax is None and an is not None):
msg = "When specifying rotation, please use both arguments to specify the axis and angle."
raise RuntimeError(msg)
rotations = [(0, 1), (0, 2), (1, 2)] # yaw, pitch, roll
if ax is None and an is None:
axes = random.randint(0, 2)
set_angles = [30, 60, 90]
angler = random.randint(0, 2)
angle = set_angles[angler]
else:
axes = ax
angle = an

r = rotations[axes]
data = rotate_array(data, angle, axes=r, interpolate=interp, reshape=False)

return data, {"ax": axes, "an": angle}


class Rotation90Augment(AugmentBase):
"""
Rotate the volume by 90 degrees.

:param data: (np.ndarray) 3d volume
:param return_all: (bool) if True, will parameters of the rotation (ax, an)
:param interp: (bool) if True, will interpolate the rotation
:param ax: (int) 0 for yaw, 1 for pitch, 2 for roll
:param an: (int) number of times to rotate, between <1 and 3>

:return: (np.ndarray) rotated volume or (np.ndarray, int, int) rotated volume and rotation parameters
"""

def __init__(self, random_seed: int = 42):
super().__init__(random_seed)

def __call__(
self,
data: np.ndarray,
**kwargs,
) -> np.ndarray:
_ = data
_ = kwargs
msg = "Rotation90Augment not implemented yet."
raise NotImplementedError(msg)
38 changes: 38 additions & 0 deletions src/caked/Transforms/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from abc import ABC, abstractmethod

import numpy as np


class TransformBase(ABC):
"""
Base class for transformations.

"""

@abstractmethod
def __init__(self):
pass

@abstractmethod
def __call__(self, mapobj, **kwargs):
msg = "The __call__ method must be implemented in the subclass"
raise NotImplementedError(msg)


class AugmentBase(ABC):
"""
Base class for augmentations.
"""

# This will need to take the hyper parameters for the augmentations

@abstractmethod
def __init__(self, random_seed: int = 42):
self.random_state = np.random.RandomState(random_seed)

@abstractmethod
def __call__(self, data, **kwargs):
msg = "The __call__ method must be implemented in the subclass"
raise NotImplementedError(msg)
Loading
Loading