Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ BUILD.info
fastsurfer.egg-info
.codespellignore
.gitignore
uv.lock
uv.lock
2 changes: 1 addition & 1 deletion CerebNet/config/checkpoint_paths.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
url:
- "https://zenodo.org/records/10390742/files"
- "https://b2share.fz-juelich.de/api/files/c6cf7bc6-2ae5-4d0e-814d-2a3cf0e1a8c5"
- "https://zenodo.org/records/10390742/files"

checkpoint:
axial: "checkpoints/CerebNet_axial_v1.0.0.pkl"
Expand Down
2 changes: 1 addition & 1 deletion CorpusCallosum/config/checkpoint_paths.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
url:
- "https://zenodo.org/records/17141933/files"
- "https://b2share.fz-juelich.de/api/files/e4eb699c-ba68-4470-9f3d-89ceeee1a334"
- "https://zenodo.org/records/17141933/files"

checkpoint:
segmentation: "checkpoints/FastSurferCC_segmentation_v1.0.0.pkl"
Expand Down
8 changes: 3 additions & 5 deletions CorpusCallosum/localization/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from monai.networks.nets import DenseNet

from CorpusCallosum.transforms.localization import CropAroundACPCFixedSize
from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML
from CorpusCallosum.utils.types import Points2dType
from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults
from FastSurferCNN.download_checkpoints import main as download_checkpoints
from FastSurferCNN.utils import Image3d, Vector2d, Vector3d
from FastSurferCNN.utils.checkpoint import get_config_file
from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT

PATCH_SIZE = (64, 64)
Expand Down Expand Up @@ -60,10 +60,8 @@ def load_model(device: torch.device) -> DenseNet:
)

download_checkpoints(cc=True)
cc_config = load_checkpoint_config_defaults(
"checkpoint",
filename=CC_YAML,
)
config_file = get_config_file("CorpusCallosum")
cc_config = load_checkpoint_config_defaults("checkpoint", filename=config_file)
checkpoint_path = FASTSURFER_ROOT / cc_config['localization']

# Load state dict
Expand Down
8 changes: 3 additions & 5 deletions CorpusCallosum/segmentation/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@

from CorpusCallosum.data import constants
from CorpusCallosum.transforms.segmentation import CropAroundACPC
from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML
from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults
from FastSurferCNN.download_checkpoints import main as download_checkpoints
from FastSurferCNN.models.networks import FastSurferVINN
from FastSurferCNN.utils import Image3d, Image4d, Shape2d, Shape3d, Shape4d, Vector2d, nibabelImage
from FastSurferCNN.utils.checkpoint import get_config_file
from FastSurferCNN.utils.parallel import thread_executor


Expand Down Expand Up @@ -70,10 +70,8 @@ def load_model(device: torch.device | None = None) -> FastSurferVINN:
model = FastSurferVINN(params)

download_checkpoints(cc=True)
cc_config: dict[str, Path] = load_checkpoint_config_defaults(
"checkpoint",
filename=CC_YAML,
)
config_file = get_config_file("CorpusCallosum")
cc_config: dict[str, Path] = load_checkpoint_config_defaults("checkpoint", filename=config_file)
checkpoint_path = constants.FASTSURFER_ROOT / cc_config['segmentation']

weights = torch.load(checkpoint_path, weights_only=True, map_location=device)
Expand Down
18 changes: 14 additions & 4 deletions CorpusCallosum/shape/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,13 +538,23 @@ def snap_cc_picture(
3. Cleans up temporary files after use.
"""
try:
# Dummy import of OpenCL to ensure it's available for whippersnappy
import OpenGL.GL # noqa: F401
from whippersnappy.core import snap1
except ImportError:
except ImportError as e:
# whippersnappy not installed
raise RuntimeError(
"The snap_cc_picture method of CCMesh requires whippersnappy, but whippersnappy was not found. "
"Please install whippersnappy!"
raise ImportError(
f"The snap_cc_picture method of CCMesh requires {e.name}, but {e.name} was not found. "
f"Please install {e.name}!",
name=e.name, path=e.path
) from None
except Exception as e:
# Catch all other types of errors,
raise RuntimeError(
"Could not import OpenGL or whippersnappy. The snap_cc_picture method of CCMesh requires OpenGL and "
"whippersnappy to render the QC thickness image. On headless servers, this also requires a virtual "
"framebuffer like xvfb.",
) from e
self.__make_parent_folder(output_path)
# Skip snapshot if there are no faces
if len(self.t) == 0:
Expand Down
19 changes: 17 additions & 2 deletions CorpusCallosum/shape/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def _gen_slice2slab_vox2vox(_slice_idx: int) -> AffineMatrix4x4:
# Mesh is fsavg_midplane (RAS); we need to transform to voxel coordinates
# fsavg ras is also on the midslice, so this is fine and we multiply in the IA and SP offsets
cc_mesh = cc_mesh.to_vox_coordinates(mesh_ras2vox=np.linalg.inv(fsavg_vox2ras @ orig2fsavg_vox2vox))
cc_surf_generated = False
if wants_output("cc_thickness_image"):
# this will also write overlay and surface
thickness_image_path = output_path("cc_thickness_image")
Expand All @@ -276,8 +277,22 @@ def _gen_slice2slab_vox2vox(_slice_idx: int) -> AffineMatrix4x4:
if wants_output("cc_thickness_overlay") else None,
"ref_image": upright_img,
}
cc_mesh.snap_cc_picture(thickness_image_path, **kwargs)
elif wants_output("cc_surf"):
try:
cc_mesh.snap_cc_picture(thickness_image_path, **kwargs)
cc_surf_generated = True
except (ImportError, ModuleNotFoundError) as e:
logger.error(
"The thickness image was not generated because whippersnappy, glfw or OpenGL are not installed."
)
logger.exception(e)
except Exception as e:
logger.error(
"The thickness image was not generated (see below). On headless Linux systems or if the "
"x-server cannot/should not be accessed due to other reasons, xvfb-run may be used to provide "
"a virtual framebuffer for offscreen rendering."
)
logger.exception(e)
if not cc_surf_generated and wants_output("cc_surf"):
surf_file_path = output_path("cc_surf")
logger.info(f"Saving surf file to {surf_file_path}")
io_futures.append(run(cc_mesh.write_fssurf, str(surf_file_path), image=upright_img))
Expand Down
2 changes: 1 addition & 1 deletion CorpusCallosum/shape/subsegment_contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ def subdivide_contour_vertical(
plt.show()




# add original contour as the final element (Full CC)
split_contours.append(contour)
Expand Down
2 changes: 1 addition & 1 deletion FastSurferCNN/config/checkpoint_paths.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
url:
- "https://zenodo.org/records/10390573/files"
- "https://b2share.fz-juelich.de/api/files/a423a576-220d-47b0-9e0c-b5b32d45fc59"
- "https://zenodo.org/records/10390573/files"

checkpoint:
axial: "checkpoints/aparc_vinn_axial_v2.0.0.pkl"
Expand Down
4 changes: 2 additions & 2 deletions FastSurferCNN/data_loader/conform.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ def is_conform(
if "check_dtype" in kwargs:
LOGGER.warning("check_dtype is deprecated, replaced by dtype=None and will be removed.")
if kwargs["check_dtype"] is False:
dtype = None
dtype: npt.DTypeLike | None = None

_vox_size, _img_size = conformed_vox_img_size(img, vox_size, img_size, threshold_1mm=threshold_1mm, vox_eps=vox_eps)

Expand Down Expand Up @@ -1006,7 +1006,7 @@ def is_conform(
checks["Dtype None"] = "IGNORED", dtype_text
else:
_dtype: npt.DTypeLike = to_dtype(dtype)
_dtype_name = _dtype.name if hasattr(_dtype, "name") else str(getattr(np.dtype(_dtype), "name", dtype))
_dtype_name = np.dtype(_dtype).name if isinstance(_dtype, (str, np.dtype)) else _dtype.__name__
checks[f"Dtype {_dtype_name}"] = np.issubdtype(img.get_data_dtype(), _dtype), dtype_text

_is_conform = all(map(lambda x: x[0], checks.values()))
Expand Down
103 changes: 34 additions & 69 deletions FastSurferCNN/download_checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3

# Copyright 2022 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,37 +13,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from CerebNet.utils.checkpoint import (
YAML_DEFAULT as CEREBNET_YAML,
)
from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML
from FastSurferCNN.utils import PLANES
from FastSurferCNN.utils.checkpoint import (
YAML_DEFAULT as VINN_YAML,
)
from functools import lru_cache
from itertools import chain
from pathlib import Path

from FastSurferCNN.utils.checkpoint import (
check_and_download_ckpts,
get_checkpoints,
get_config_file,
load_checkpoint_config_defaults,
)
from HypVINN.utils.checkpoint import YAML_DEFAULT as HYPVINN_YAML
from FastSurferCNN.utils.parallel import thread_executor


class ConfigCache:
def vinn_url(self):
return load_checkpoint_config_defaults("url", filename=VINN_YAML)

def cerebnet_url(self):
return load_checkpoint_config_defaults("url", filename=CEREBNET_YAML)
@classmethod
@lru_cache
def url(cls, module: str) -> list[str]:
return load_checkpoint_config_defaults("url", get_config_file(module))

def hypvinn_url(self):
return load_checkpoint_config_defaults("url", filename=HYPVINN_YAML)

def cc_url(self):
return load_checkpoint_config_defaults("url", filename=CC_YAML)
@classmethod
@lru_cache
def checkpoint(cls, module: str) -> dict[str, Path]:
return load_checkpoint_config_defaults("checkpoint", get_config_file(module))

def all_urls(self):
return self.vinn_url() + self.cerebnet_url() + self.hypvinn_url() + self.cc_url()
@classmethod
def all_urls(cls) -> list[str]:
return list(chain(*(cls.url(mod) for mod in ("FastSurferCNN", "CorpusCallosum", "CerebNet", "HypVINN"))))


defaults = ConfigCache()
Expand Down Expand Up @@ -93,9 +88,8 @@ def make_parser():
type=str,
default=None,
help=f"Specify you own base URL. This is applied to all models. \n"
f"Default for VINN: {defaults.vinn_url()} \n"
f"Default for CerebNet: {defaults.cerebnet_url()} \n"
f"Default for HypVINN: {defaults.hypvinn_url()}",
f"Default for VINN: {defaults.url('FastSurferCNN')} \n" + \
"\n".join(f"Default for {mod}: {defaults.url(mod)}" for mod in ("CerebNet", "CorpusCallosum", "HypVINN")),
)
parser.add_argument(
"files",
Expand All @@ -116,50 +110,16 @@ def main(
url: str | None = None,
) -> int | str:
if not vinn and not files and not cerebnet and not hypvinn and not cc and not all:
return ("Specify either files to download or --vinn, --cerebnet, "
"--hypvinn, or --all, see help -h.")
return "Specify either files to download or --vinn, --cerebnet, --cc, --hypvinn, or --all, see help -h."

futures = []
all_errors = []
try:
# FastSurferVINN checkpoints
if vinn or all:
vinn_config = load_checkpoint_config_defaults(
"checkpoint",
filename=VINN_YAML,
)
get_checkpoints(
*(vinn_config[plane] for plane in PLANES),
urls=defaults.vinn_url() if url is None else [url]
)
# CerebNet checkpoints
if cerebnet or all:
cerebnet_config = load_checkpoint_config_defaults(
"checkpoint",
filename=CEREBNET_YAML,
)
get_checkpoints(
*(cerebnet_config[plane] for plane in PLANES),
urls=defaults.cerebnet_url() if url is None else [url],
)
# HypVINN checkpoints
if hypvinn or all:
hypvinn_config = load_checkpoint_config_defaults(
"checkpoint",
filename=HYPVINN_YAML,
)
get_checkpoints(
*(hypvinn_config[plane] for plane in PLANES),
urls=defaults.hypvinn_url() if url is None else [url],
)
# Corpus Callosum checkpoints
if cc or all:
cc_config = load_checkpoint_config_defaults(
"checkpoint",
filename=CC_YAML,
)
get_checkpoints(
*(cc_config[model] for model in cc_config.keys()),
urls=defaults.cc_url() if url is None else [url],
)
for mod, sel in (("FastSurferCNN", vinn), ("CerebNet", cerebnet), ("HypVINN", hypvinn), ("CorpusCallosum", cc)):
if sel or all:
urls = defaults.url(mod) if url is None else [url]
futures.extend(thread_executor().submit(get_checkpoints, file, urls=urls)
for key, file in defaults.checkpoint(mod).items())
for fname in files:
check_and_download_ckpts(
fname,
Expand All @@ -168,8 +128,13 @@ def main(
except Exception as e:
from traceback import print_exception
print_exception(e)
return e.args[0]
return 0
all_errors = [e.args[0]]
for f in futures:
if e := f.exception():
from traceback import print_exception
print_exception(e)
all_errors.append(f.exception().args[0])
return "\n".join(all_errors) or 0


if __name__ == "__main__":
Expand Down
27 changes: 20 additions & 7 deletions FastSurferCNN/run_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
# IMPORTS
import argparse
import sys
import warnings
from collections.abc import Iterator, Sequence
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from pathlib import Path
Expand All @@ -44,14 +45,13 @@
from FastSurferCNN.utils import PLANES, Plane, logging, nibabelImage, parser_defaults
from FastSurferCNN.utils.arg_types import OrientationType, VoxSizeOption
from FastSurferCNN.utils.arg_types import vox_size as _vox_size
from FastSurferCNN.utils.checkpoint import get_checkpoints, load_checkpoint_config_defaults
from FastSurferCNN.utils.checkpoint import get_checkpoints, get_config_file, load_checkpoint_config_defaults
from FastSurferCNN.utils.common import SubjectDirectory, SubjectList, find_device, handle_cuda_memory_exception
from FastSurferCNN.utils.load_config import load_config
from FastSurferCNN.utils.parallel import SerialExecutor, pipeline
from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT, SubjectDirectoryConfig
from FastSurferCNN.utils.parser_defaults import SubjectDirectoryConfig

LOGGER = logging.getLogger(__name__)
CHECKPOINT_PATHS_FILE = FASTSURFER_ROOT / "FastSurferCNN/config/checkpoint_paths.yaml"


##
Expand Down Expand Up @@ -223,6 +223,16 @@ def __init__(
if self.device.type == "cpu" and viewagg_device in ("auto", "cpu"):
self.viewagg_device = self.device
else:
if self.device.type == "cuda" and not torch.cuda.is_initialized():
with warnings.catch_warnings():
warnings.simplefilter("error")
try:
torch.cuda.init()
except RuntimeError as err:
LOGGER.critical("Failed to initialize cuda device, maybe incompatible CUDA version?")
LOGGER.exception(err)
raise err

# check, if GPU is big enough to run view agg on it (this currently takes the memory of the passed device)
self.viewagg_device = find_device(
viewagg_device,
Expand Down Expand Up @@ -551,11 +561,12 @@ def _add_sd_help(action: argparse.Action) -> None:
parser_defaults.modify_argument(parser, "--sd", _add_sd_help)

# 3. Checkpoint to load
config_file = get_config_file("FastSurferCNN")
files: dict[Plane, str | Path] = {k: "default" for k in PLANES}
parser = parser_defaults.add_plane_flags(parser, "checkpoint", files, CHECKPOINT_PATHS_FILE)
parser = parser_defaults.add_plane_flags(parser, "checkpoint", files, config_file)

# 4. CFG-file with default options for network
parser = parser_defaults.add_plane_flags(parser, "config", files, CHECKPOINT_PATHS_FILE)
parser = parser_defaults.add_plane_flags(parser, "config", files, config_file)

# 5. technical parameters
image_flags = ["vox_size", "conform_to_1mm_threshold", "orientation", "image_size", "device"]
Expand Down Expand Up @@ -613,8 +624,10 @@ def main(
# Download checkpoints if they do not exist
# see utils/checkpoint.py for default paths
LOGGER.info("Checking or downloading default checkpoints ...")

urls = load_checkpoint_config_defaults("url", filename=CHECKPOINT_PATHS_FILE)

config_file = get_config_file("FastSurferCNN")

urls = load_checkpoint_config_defaults("url", filename=config_file)

get_checkpoints(ckpt_ax, ckpt_cor, ckpt_sag, urls=urls)

Expand Down
3 changes: 2 additions & 1 deletion FastSurferCNN/segstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2279,7 +2279,8 @@ def pv_calc_patch(
for p, gc in zip(slicer_patch, global_crop, strict=False))

label_lookup = np.unique(seg[slicer_small_patch])
maxlabels = label_lookup[-1] + 1
# make sure to promote label_lookup to int64 to avoid overflow (numpy2)
maxlabels = int(label_lookup[-1]) + 1
if maxlabels > 100_000:
raise RuntimeError("Maximum number of labels above 100000!")
# create a view for the current patch border
Expand Down
Loading
Loading