From 6fe68a55582f14ed8a1fd1d437050fa96c1011f4 Mon Sep 17 00:00:00 2001 From: Pommier Date: Sat, 25 Oct 2025 17:13:26 +0200 Subject: [PATCH] Build modular PyQt cytogenetic viewer --- CytoViewer/__init__.py | 9 + CytoViewer/__main__.py | 4 + README.md | 25 ++ cyto_viewer/__init__.py | 3 + cyto_viewer/__main__.py | 5 + cyto_viewer/viewer.py | 558 ++++++++++++++++++++++++++++++++++++++ viewer/__init__.py | 5 + viewer/__main__.py | 4 + viewer/alignment.py | 97 +++++++ viewer/batch.py | 76 ++++++ viewer/config.py | 49 ++++ viewer/main.py | 545 +++++++++++++++++++++++++++++++++++++ viewer/preprocessing.py | 243 +++++++++++++++++ viewer/segmentation.py | 228 ++++++++++++++++ viewer/utils/__init__.py | 0 viewer/utils/data.py | 111 ++++++++ viewer/utils/fake_data.py | 55 ++++ 17 files changed, 2017 insertions(+) create mode 100644 CytoViewer/__init__.py create mode 100644 CytoViewer/__main__.py create mode 100644 cyto_viewer/__init__.py create mode 100644 cyto_viewer/__main__.py create mode 100644 cyto_viewer/viewer.py create mode 100644 viewer/__init__.py create mode 100644 viewer/__main__.py create mode 100644 viewer/alignment.py create mode 100644 viewer/batch.py create mode 100644 viewer/config.py create mode 100644 viewer/main.py create mode 100644 viewer/preprocessing.py create mode 100644 viewer/segmentation.py create mode 100644 viewer/utils/__init__.py create mode 100644 viewer/utils/data.py create mode 100644 viewer/utils/fake_data.py diff --git a/CytoViewer/__init__.py b/CytoViewer/__init__.py new file mode 100644 index 00000000000..c5c015c338c --- /dev/null +++ b/CytoViewer/__init__.py @@ -0,0 +1,9 @@ +"""Compatibility wrapper to allow ``python -m CytoViewer``. + +The original notebooks referenced a capitalised package name. Keep the +modern implementation in :mod:`cyto_viewer` as the single source of +truth and re-export its public entry points here. +""" +from cyto_viewer.viewer import main, run_viewer # noqa: F401 + +__all__ = ["main", "run_viewer"] diff --git a/CytoViewer/__main__.py b/CytoViewer/__main__.py new file mode 100644 index 00000000000..d24e80379d3 --- /dev/null +++ b/CytoViewer/__main__.py @@ -0,0 +1,4 @@ +from cyto_viewer.__main__ import main + +if __name__ == "__main__": + main() diff --git a/README.md b/README.md index 56403776253..949108fc416 100644 --- a/README.md +++ b/README.md @@ -51,3 +51,28 @@ The latter deep-learning framework is supposed to be more efficient than the seg * **LowRes_13434_overlapping_pairs.h5** : 13434 pairs of overlapping chromosomes generated from the two previous images. This dataset is intended to train a supervised learning algorithm to resolve overlapping chromosomes. The dataset is stored as a numpy array and saved in a hdf5 file. Compared to the DAPI and Cy3 images,the resolution was decreased by two. * **overlapping_chromosomes_examples.h5**: smaller dataset (~2000 images). The resolution of the images is the same than the DAPI/Cy3 images. * **UltraSmall-COCO-Dataset_125_images-json.zip** a very small dataset with 125 images of overlapping chromosomes and their annotation file in COCO format generated online with makesense.ai + +## Cytogenetic image multi-channel viewer + +The repository now ships a full PyQt5 application dedicated to cytogenetic imaging workflows. Slides are represented by folders that contain numbered metaphases, themselves populated with per-channel sub-directories (e.g. `dapi/`, `cy3/`, `cy5/`, `fitc/`). 
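+For reference, a dataset root is expected to look roughly like the sketch below (folder names are illustrative):
+
+```text
+dataset_root/
+└── slide_01/
+    ├── 01/
+    │   ├── dapi/   # one or more .tif images
+    │   ├── cy3/
+    │   └── cy5/
+    └── 02/
+        ├── dapi/
+        └── cy3/
+```
+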
+The viewer arranges the content as a digital *light table* and exposes dedicated tooling for preprocessing, alignment, segmentation, configuration, and batch execution.
+
+### Launching the application
+
+```bash
+python -m viewer /path/to/your/dataset
+```
+
+If no argument is supplied, the viewer scans a list of common dataset roots (`jpp21_downloaded_data`, `Raw images/jpp21`, `dataset`, the current folder) and, failing that, generates a small synthetic dataset so the interface remains functional out of the box.
+
+> ℹ️ **Wayland users:** on GNOME Wayland, force `QT_QPA_PLATFORM=wayland` (or `=xcb`) before launching to silence the Qt warning about `XDG_SESSION_TYPE`.
+
+### Interface overview
+
+The main window is split into an image canvas with navigation controls (slide/metaphase selectors and per-channel toggles) and a side panel containing four tabs:
+
+* **Preprocessing** – per-channel white top-hat filtering, background subtraction (median or mode), percentile-based histogram stretching, display mode selection (grayscale, inverted, pseudo-colour) and HSV colour pickers for the latter. Results can be saved to disk, and channel-specific parameters are persisted in `config.json`.
+* **Alignment** – manual registration of every spectral component relative to DAPI using arrow buttons or numeric spin boxes. Offsets are stored in `alignment.json` and can be propagated to all metaphases on the slide.
+* **Segmentation** – adaptive thresholding (`cv2.adaptiveThreshold`) with adjustable block size and offset, live previews, overlay/contour/filled display modes, undo/redo-enabled brush and eraser tools, object size filtering, and optional removal of border-touching regions. Masks can be saved as TIFF/PNG files.
+* **Batch** – select a dataset root, tick the preprocessing/alignment/segmentation steps, and run the pipeline for every slide and metaphase while tracking progress and textual logs.
+
+All parameters edited through the interface are written back to the slide-level `config.json`/`alignment.json` files when the application closes, making it straightforward to resume a session. The same configuration is reused by the batch executor, so manual fine-tuning directly feeds the automated pipeline.
diff --git a/cyto_viewer/__init__.py b/cyto_viewer/__init__.py
new file mode 100644
index 00000000000..df9c9dc6289
--- /dev/null
+++ b/cyto_viewer/__init__.py
@@ -0,0 +1,3 @@
+"""Cytogenetic image viewer package."""
+
+__all__ = ["viewer"]
diff --git a/cyto_viewer/__main__.py b/cyto_viewer/__main__.py
new file mode 100644
index 00000000000..4e7d64285a1
--- /dev/null
+++ b/cyto_viewer/__main__.py
@@ -0,0 +1,5 @@
+from .viewer import main
+
+
+if __name__ == "__main__":
+    main()
diff --git a/cyto_viewer/viewer.py b/cyto_viewer/viewer.py
new file mode 100644
index 00000000000..0344ecede2c
--- /dev/null
+++ b/cyto_viewer/viewer.py
@@ -0,0 +1,558 @@
+"""PyQt-based viewer for cytogenetic metaphase images.
+
+This module implements a navigation interface similar to a light table
+that lets users browse 12-bit spectral components (DAPI, Cy3, Cy5) and a
+pseudo-colour composite. The implementation focusses on providing
+adjustable linear histogram stretching per channel and specialised
+colour assignments in HSV space instead of simply stacking channels in
+RGB.
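+
+Typical usage (the dataset path below is illustrative)::
+
+    python -m cyto_viewer /path/to/dataset
+
+or, programmatically::
+
+    from pathlib import Path
+    from cyto_viewer.viewer import run_viewer
+
+    run_viewer(Path("/path/to/dataset"))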
+""" +from __future__ import annotations + +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Tuple + +import numpy as np +from matplotlib.colors import hsv_to_rgb +from skimage import exposure, io + +from PyQt5 import QtCore, QtGui, QtWidgets + +CHANNEL_ORDER = ["dapi", "cy3", "cy5"] +CHANNEL_NAMES = {"dapi": "DAPI", "cy3": "Cy3", "cy5": "Cy5"} + +# HSV hues assigned to each spectral component (degrees / 360). +CHANNEL_HUES = { + "dapi": 210.0 / 360.0, # teal leaning blue + "cy3": 40.0 / 360.0, # orange-yellow + "cy5": 350.0 / 360.0, # cherry red +} + + +@dataclass +class HistogramSettings: + """Configuration for linear histogram stretching.""" + + low_percent: float = 1.0 + high_percent: float = 99.0 + + def clamp(self) -> None: + self.low_percent = max(0.0, min(self.low_percent, 100.0)) + self.high_percent = max(0.0, min(self.high_percent, 100.0)) + if self.high_percent <= self.low_percent: + # Maintain at least 1 % separation to avoid degenerate scaling. + if self.low_percent >= 99.0: + self.low_percent = 98.0 + self.high_percent = 99.0 + else: + self.high_percent = min(100.0, self.low_percent + 1.0) + + +@dataclass +class Metaphase: + """Represents one metaphase folder and its spectral components.""" + + name: str + channels: Dict[str, List[Path]] + + def available_channels(self) -> Iterable[str]: + return (ch for ch in CHANNEL_ORDER if ch in self.channels) + + +class ImageRepository: + """Loads and caches 12-bit TIFF images from a dataset hierarchy.""" + + def __init__(self, root: Path) -> None: + self.root = root + + def metaphases(self) -> List[Metaphase]: + entries: List[Metaphase] = [] + for path in sorted(self.root.iterdir()): + if not path.is_dir(): + continue + channels: Dict[str, List[Path]] = {} + for channel in CHANNEL_ORDER: + chan_dir = path / channel + if not chan_dir.exists() or not chan_dir.is_dir(): + continue + files = sorted( + p for p in chan_dir.iterdir() if p.suffix.lower() in {".tif", ".tiff"} + ) + if files: + channels[channel] = files + if channels: + entries.append(Metaphase(path.name, channels)) + return entries + + @lru_cache(maxsize=256) + def load_image(self, path: Path) -> np.ndarray: + data = io.imread(str(path)) + if data.ndim != 2: + raise ValueError(f"Expected 2D image for {path}, got shape {data.shape}") + return data.astype(np.float32) + + +class ThumbnailCache: + """Caches generated thumbnails to avoid recomputing on every refresh.""" + + def __init__(self) -> None: + self._cache: Dict[Tuple[str, str, int, float, float], QtGui.QPixmap] = {} + + def get( + self, + key: Tuple[str, str, int, float, float], + ) -> Optional[QtGui.QPixmap]: + return self._cache.get(key) + + def store( + self, + key: Tuple[str, str, int, float, float], + pixmap: QtGui.QPixmap, + ) -> None: + self._cache[key] = pixmap + + def clear(self) -> None: + self._cache.clear() + + +def apply_histogram_stretch( + image: np.ndarray, settings: HistogramSettings +) -> np.ndarray: + """Apply a linear histogram stretch based on percentile cut-offs.""" + + # Ensure we always operate on a float array without mutating the cached source. 
+ working = image.astype(np.float32, copy=False) + + settings.clamp() + with np.errstate(invalid="ignore"): + low, high = np.nanpercentile( + working, [settings.low_percent, settings.high_percent] + ) + + if not np.isfinite(low) or not np.isfinite(high): + return np.zeros_like(working) + + if high < low: + low, high = high, low + + if np.isclose(high, low): + clipped = np.clip(working, low, high) + if high <= 0: + return np.zeros_like(working) + return (clipped - low) / max(high - low, 1e-6) + + try: + stretched = exposure.rescale_intensity( + working, in_range=(low, high), out_range=(0.0, 1.0) + ) + except Exception: + clipped = np.clip(working, low, high) + return (clipped - low) / max(high - low, 1e-6) + + # ``rescale_intensity`` may return a view with the original dtype. Guarantee float32. + return stretched.astype(np.float32, copy=False) + + +def channel_to_qimage(image: np.ndarray) -> QtGui.QImage: + """Convert a normalised single-channel image into a grayscale QImage.""" + + arr = np.clip(image, 0.0, 1.0) + data = (arr * 255).astype(np.uint8) + height, width = data.shape + bytes_per_line = width + return QtGui.QImage( + data.data, width, height, bytes_per_line, QtGui.QImage.Format_Grayscale8 + ).copy() + + +def rgb_array_to_qimage(image: np.ndarray) -> QtGui.QImage: + """Convert a normalised RGB array into a QImage.""" + + arr = np.clip(image, 0.0, 1.0) + data = (arr * 255).astype(np.uint8) + height, width, _ = data.shape + bytes_per_line = 3 * width + return QtGui.QImage( + data.data, width, height, bytes_per_line, QtGui.QImage.Format_RGB888 + ).copy() + + +def make_pseudo_colour( + stretched_images: Dict[str, np.ndarray], + hues: Dict[str, float] = CHANNEL_HUES, +) -> np.ndarray: + """Generate a pseudo-colour composite by tinting each component in HSV.""" + + rgb_layers: List[np.ndarray] = [] + for channel in CHANNEL_ORDER: + image = stretched_images.get(channel) + if image is None: + continue + hue = hues.get(channel, 0.0) + hsv = np.stack( + [np.full_like(image, hue), np.ones_like(image), image], axis=-1 + ) + rgb = hsv_to_rgb(hsv) + rgb_layers.append(rgb) + if not rgb_layers: + raise ValueError("No channels provided for pseudo-colour composite") + composite = np.zeros_like(rgb_layers[0]) + for layer in rgb_layers: + composite += layer + return np.clip(composite, 0.0, 1.0) + + +class ChannelControl(QtWidgets.QGroupBox): + """Widget with percentile controls for a single channel.""" + + settings_changed = QtCore.pyqtSignal(str, HistogramSettings) + + def __init__(self, channel: str, settings: HistogramSettings, parent: Optional[QtWidgets.QWidget] = None) -> None: + super().__init__(CHANNEL_NAMES.get(channel, channel), parent) + self.channel = channel + self.settings = settings + self.low_spin = QtWidgets.QDoubleSpinBox() + self.high_spin = QtWidgets.QDoubleSpinBox() + for spin in (self.low_spin, self.high_spin): + spin.setRange(0.0, 100.0) + spin.setDecimals(2) + spin.setSingleStep(0.25) + self.low_spin.setValue(settings.low_percent) + self.high_spin.setValue(settings.high_percent) + layout = QtWidgets.QFormLayout(self) + layout.addRow(self.tr("Seuil bas (%)"), self.low_spin) + layout.addRow(self.tr("Seuil haut (%)"), self.high_spin) + self.low_spin.valueChanged.connect(self._emit_changed) + self.high_spin.valueChanged.connect(self._emit_changed) + + def _emit_changed(self) -> None: + self.settings.low_percent = self.low_spin.value() + self.settings.high_percent = self.high_spin.value() + self.settings.clamp() + # Update spin boxes in case clamp adjusted the values. 
+ self.low_spin.blockSignals(True) + self.high_spin.blockSignals(True) + self.low_spin.setValue(self.settings.low_percent) + self.high_spin.setValue(self.settings.high_percent) + self.low_spin.blockSignals(False) + self.high_spin.blockSignals(False) + self.settings_changed.emit(self.channel, self.settings) + + +class HistogramPanel(QtWidgets.QWidget): + """Panel hosting controls for all channel histogram settings.""" + + settings_changed = QtCore.pyqtSignal() + + def __init__(self, parent: Optional[QtWidgets.QWidget] = None) -> None: + super().__init__(parent) + self.controls: Dict[str, ChannelControl] = {} + layout = QtWidgets.QHBoxLayout(self) + for channel in CHANNEL_ORDER: + control = ChannelControl(channel, HistogramSettings()) + self.controls[channel] = control + control.settings_changed.connect(self._relay) + layout.addWidget(control) + layout.addStretch(1) + + def get_settings(self) -> Dict[str, HistogramSettings]: + return {ch: control.settings for ch, control in self.controls.items()} + + def _relay(self, channel: str, settings: HistogramSettings) -> None: + del channel, settings + self.settings_changed.emit() + + +class LightTable(QtWidgets.QMainWindow): + """Main application window implementing the light-table viewer.""" + + def __init__(self, repo: ImageRepository, parent: Optional[QtWidgets.QWidget] = None) -> None: + super().__init__(parent) + self.setWindowTitle("Cytogenetic Light Table Viewer") + self.repo = repo + self.metaphases = repo.metaphases() + self.thumbnail_cache = ThumbnailCache() + self.hist_panel = HistogramPanel() + self.hist_panel.settings_changed.connect(self.refresh_previews) + + self.table = QtWidgets.QTableWidget() + self.table.setColumnCount(len(CHANNEL_ORDER) + 1) + headers = [CHANNEL_NAMES[ch] for ch in CHANNEL_ORDER] + [self.tr("Pseudo-couleur")] + self.table.setHorizontalHeaderLabels(headers) + self.table.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectItems) + self.table.setIconSize(QtCore.QSize(160, 160)) + self.table.horizontalHeader().setSectionResizeMode(QtWidgets.QHeaderView.ResizeToContents) + self.table.verticalHeader().setVisible(False) + self.table.itemSelectionChanged.connect(self._update_preview_from_selection) + + self.preview_label = QtWidgets.QLabel(self.tr("Sélectionnez une image.")) + self.preview_label.setAlignment(QtCore.Qt.AlignCenter) + self.preview_label.setMinimumSize(320, 320) + self.preview_label.setBackgroundRole(QtGui.QPalette.Base) + self.preview_label.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) + + self.file_list = QtWidgets.QListWidget() + self.file_list.itemSelectionChanged.connect(self._file_index_changed) + + right_pane = QtWidgets.QWidget() + right_layout = QtWidgets.QVBoxLayout(right_pane) + right_layout.addWidget(self.preview_label, stretch=3) + right_layout.addWidget(QtWidgets.QLabel(self.tr("Fichiers disponibles"))) + right_layout.addWidget(self.file_list, stretch=1) + right_layout.addWidget(self.hist_panel, stretch=0) + + splitter = QtWidgets.QSplitter() + splitter.addWidget(self.table) + splitter.addWidget(right_pane) + splitter.setStretchFactor(0, 1) + splitter.setStretchFactor(1, 1) + + self.setCentralWidget(splitter) + + self._metaphase_indices: Dict[str, Dict[str, int]] = { + meta.name: {channel: 0 for channel in meta.channels} + for meta in self.metaphases + } + + self.populate_table() + + def populate_table(self) -> None: + self.table.setRowCount(len(self.metaphases)) + for row, meta in enumerate(self.metaphases): + name_item = 
QtWidgets.QTableWidgetItem(meta.name) + name_item.setFlags(QtCore.Qt.ItemIsEnabled) + self.table.setVerticalHeaderItem(row, name_item) + for col, channel in enumerate(CHANNEL_ORDER): + pixmap = self._make_channel_thumbnail(meta, channel) + item = QtWidgets.QTableWidgetItem() + if pixmap is not None: + item.setIcon(QtGui.QIcon(pixmap)) + item.setToolTip(f"{meta.name} — {CHANNEL_NAMES[channel]}") + else: + item.setText(self.tr("(absent)")) + item.setFlags(QtCore.Qt.NoItemFlags) + self.table.setItem(row, col, item) + pseudo_pixmap = self._make_pseudo_thumbnail(meta) + pseudo_item = QtWidgets.QTableWidgetItem() + if pseudo_pixmap is not None: + pseudo_item.setIcon(QtGui.QIcon(pseudo_pixmap)) + pseudo_item.setToolTip(f"{meta.name} — {self.tr('Pseudo-couleur')}") + else: + pseudo_item.setText(self.tr("(indisponible)")) + pseudo_item.setFlags(QtCore.Qt.NoItemFlags) + self.table.setItem(row, len(CHANNEL_ORDER), pseudo_item) + self.table.resizeRowsToContents() + + def _make_channel_thumbnail(self, meta: Metaphase, channel: str) -> Optional[QtGui.QPixmap]: + files = meta.channels.get(channel) + if not files: + return None + key = (meta.name, channel, 0, self.hist_panel.controls[channel].settings.low_percent, + self.hist_panel.controls[channel].settings.high_percent) + cached = self.thumbnail_cache.get(key) + if cached is not None: + return cached + image = self.repo.load_image(files[0]) + stretched = apply_histogram_stretch(image, self.hist_panel.controls[channel].settings) + pixmap = self._array_to_pixmap(stretched) + self.thumbnail_cache.store(key, pixmap) + return pixmap + + def _make_pseudo_thumbnail(self, meta: Metaphase) -> Optional[QtGui.QPixmap]: + stretched: Dict[str, np.ndarray] = {} + for channel in CHANNEL_ORDER: + files = meta.channels.get(channel) + if not files: + continue + image = self.repo.load_image(files[0]) + stretched[channel] = apply_histogram_stretch( + image, self.hist_panel.controls[channel].settings + ) + if not stretched: + return None + composite = make_pseudo_colour(stretched) + pixmap = self._array_to_pixmap(composite) + return pixmap + + def _array_to_pixmap(self, array: np.ndarray) -> QtGui.QPixmap: + if array.ndim == 2: + qimage = channel_to_qimage(array) + else: + qimage = rgb_array_to_qimage(array) + pixmap = QtGui.QPixmap.fromImage(qimage) + return pixmap.scaled( + self.table.iconSize(), + QtCore.Qt.KeepAspectRatio, + QtCore.Qt.SmoothTransformation, + ) + + def refresh_previews(self) -> None: + self.thumbnail_cache.clear() + self.populate_table() + self._update_preview_from_selection() + + def _selected_cell(self) -> Optional[Tuple[int, int]]: + items = self.table.selectedIndexes() + if not items: + return None + index = items[0] + return index.row(), index.column() + + def _update_preview_from_selection(self) -> None: + selection = self._selected_cell() + if selection is None: + return + row, column = selection + if row >= len(self.metaphases): + return + meta = self.metaphases[row] + if column < len(CHANNEL_ORDER): + channel = CHANNEL_ORDER[column] + self._show_channel(meta, channel) + else: + self._show_pseudo(meta) + + def _show_channel(self, meta: Metaphase, channel: str) -> None: + files = meta.channels.get(channel) + if not files: + self.preview_label.setText(self.tr("Canal indisponible")) + self.file_list.clear() + return + index = self._metaphase_indices[meta.name].get(channel, 0) + index = max(0, min(index, len(files) - 1)) + self._metaphase_indices[meta.name][channel] = index + self.file_list.blockSignals(True) + self.file_list.clear() + for i, file in 
enumerate(files): + item = QtWidgets.QListWidgetItem(file.name) + item.setData(QtCore.Qt.UserRole, i) + self.file_list.addItem(item) + if i == index: + item.setSelected(True) + self.file_list.blockSignals(False) + image = self.repo.load_image(files[index]) + stretched = apply_histogram_stretch(image, self.hist_panel.controls[channel].settings) + pixmap = QtGui.QPixmap.fromImage(channel_to_qimage(stretched)) + self.preview_label.setPixmap( + pixmap.scaled( + self.preview_label.size(), + QtCore.Qt.KeepAspectRatio, + QtCore.Qt.SmoothTransformation, + ) + ) + self.preview_label.setToolTip(f"{meta.name} — {CHANNEL_NAMES[channel]} — {files[index].name}") + + def _show_pseudo(self, meta: Metaphase) -> None: + self.file_list.blockSignals(True) + self.file_list.clear() + self.file_list.blockSignals(False) + stretched: Dict[str, np.ndarray] = {} + for channel in CHANNEL_ORDER: + files = meta.channels.get(channel) + if not files: + continue + index = self._metaphase_indices[meta.name].get(channel, 0) + index = max(0, min(index, len(files) - 1)) + image = self.repo.load_image(files[index]) + stretched[channel] = apply_histogram_stretch( + image, self.hist_panel.controls[channel].settings + ) + if not stretched: + self.preview_label.setText(self.tr("Pas de données pour la pseudo-couleur")) + return + composite = make_pseudo_colour(stretched) + pixmap = QtGui.QPixmap.fromImage(rgb_array_to_qimage(composite)) + self.preview_label.setPixmap( + pixmap.scaled( + self.preview_label.size(), + QtCore.Qt.KeepAspectRatio, + QtCore.Qt.SmoothTransformation, + ) + ) + self.preview_label.setToolTip(f"{meta.name} — {self.tr('Pseudo-couleur')}") + + def resizeEvent(self, event: QtGui.QResizeEvent) -> None: # noqa: N802 + super().resizeEvent(event) + self._update_preview_from_selection() + + def _file_index_changed(self) -> None: + selection = self._selected_cell() + if selection is None: + return + row, column = selection + meta = self.metaphases[row] + if column >= len(CHANNEL_ORDER): + return + channel = CHANNEL_ORDER[column] + selected_items = self.file_list.selectedItems() + if not selected_items: + return + index = selected_items[0].data(QtCore.Qt.UserRole) + self._metaphase_indices[meta.name][channel] = index + self._show_channel(meta, channel) + + +def run_viewer(root: Path) -> None: + """Start the Qt application for the provided dataset root.""" + + app = QtWidgets.QApplication.instance() + if app is None: + app = QtWidgets.QApplication([]) + repo = ImageRepository(root) + window = LightTable(repo) + if not window.metaphases: + QtWidgets.QMessageBox.warning( + window, + window.tr("Données manquantes"), + window.tr("Aucune métaphase n'a été trouvée dans le dossier fourni."), + ) + window.resize(1200, 700) + window.show() + app.exec_() + + +def _candidate_roots(start: Path) -> Iterable[Path]: + """Yield plausible dataset folders relative to ``start`` and its parents.""" + + dataset_names = [ + Path("Raw images") / "jpp21", + Path("dataset"), + Path("jpp21_downloaded_data"), + Path("jpp21"), + ] + for base in [start, *start.parents]: + for name in dataset_names: + candidate = (base / name).resolve() + yield candidate + yield base.resolve() + + +def find_default_root() -> Path: + """Try to guess a default dataset location.""" + + for candidate in _candidate_roots(Path.cwd()): + if candidate.exists() and candidate.is_dir(): + return candidate + return Path.cwd() + + +def main(argv: Optional[List[str]] = None) -> None: + import argparse + + parser = argparse.ArgumentParser(description="Viewer for cytogenetic 
spectral components") + parser.add_argument( + "root", + nargs="?", + default=str(find_default_root()), + help="Folder containing metaphase sub-directories", + ) + args, unknown = parser.parse_known_args(argv) + if unknown: + print(f"Ignoring unrecognised Qt arguments: {unknown}") + run_viewer(Path(args.root)) + + +if __name__ == "__main__": + main() diff --git a/viewer/__init__.py b/viewer/__init__.py new file mode 100644 index 00000000000..08c9c0d1f5f --- /dev/null +++ b/viewer/__init__.py @@ -0,0 +1,5 @@ +"""Cytogenetic Image Multi-Channel Viewer package.""" + +from .main import main + +__all__ = ["main"] diff --git a/viewer/__main__.py b/viewer/__main__.py new file mode 100644 index 00000000000..40e2b013f61 --- /dev/null +++ b/viewer/__main__.py @@ -0,0 +1,4 @@ +from .main import main + +if __name__ == "__main__": + main() diff --git a/viewer/alignment.py b/viewer/alignment.py new file mode 100644 index 00000000000..3793bf3b76d --- /dev/null +++ b/viewer/alignment.py @@ -0,0 +1,97 @@ +"""Manual alignment controls.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Tuple + +import numpy as np +from PyQt5 import QtCore, QtWidgets + + +@dataclass +class AlignmentState: + offsets: Dict[str, Tuple[int, int]] + + def to_dict(self) -> Dict[str, Tuple[int, int]]: + return {k: list(v) for k, v in self.offsets.items()} + + @classmethod + def from_dict(cls, data: Dict[str, Tuple[int, int]]) -> "AlignmentState": + return cls(offsets={k: tuple(v) for k, v in data.items()}) + + +class AlignmentPanel(QtWidgets.QWidget): + translationChanged = QtCore.pyqtSignal(str, int, int) + applyAllRequested = QtCore.pyqtSignal() + + def __init__(self, parent: QtWidgets.QWidget | None = None) -> None: + super().__init__(parent) + self._channel = "cy3" + self._offsets: Dict[str, Tuple[int, int]] = {} + self._build_ui() + + def _build_ui(self) -> None: + layout = QtWidgets.QVBoxLayout(self) + self.label = QtWidgets.QLabel("Channel: cy3 relative to DAPI") + layout.addWidget(self.label) + + grid = QtWidgets.QGridLayout() + self.dy_spin = QtWidgets.QSpinBox() + self.dy_spin.setRange(-200, 200) + self.dx_spin = QtWidgets.QSpinBox() + self.dx_spin.setRange(-200, 200) + grid.addWidget(QtWidgets.QLabel("dy"), 0, 0) + grid.addWidget(self.dy_spin, 0, 1) + grid.addWidget(QtWidgets.QLabel("dx"), 1, 0) + grid.addWidget(self.dx_spin, 1, 1) + layout.addLayout(grid) + + arrows = QtWidgets.QGridLayout() + self.btn_up = QtWidgets.QPushButton("↑") + self.btn_down = QtWidgets.QPushButton("↓") + self.btn_left = QtWidgets.QPushButton("←") + self.btn_right = QtWidgets.QPushButton("→") + arrows.addWidget(self.btn_up, 0, 1) + arrows.addWidget(self.btn_left, 1, 0) + arrows.addWidget(self.btn_right, 1, 2) + arrows.addWidget(self.btn_down, 2, 1) + layout.addLayout(arrows) + + self.apply_all_button = QtWidgets.QPushButton("Apply to all metaphases") + layout.addWidget(self.apply_all_button) + layout.addStretch() + + self.dy_spin.valueChanged.connect(self._emit_change) + self.dx_spin.valueChanged.connect(self._emit_change) + self.btn_up.clicked.connect(lambda: self.dy_spin.setValue(self.dy_spin.value() - 1)) + self.btn_down.clicked.connect(lambda: self.dy_spin.setValue(self.dy_spin.value() + 1)) + self.btn_left.clicked.connect(lambda: self.dx_spin.setValue(self.dx_spin.value() - 1)) + self.btn_right.clicked.connect(lambda: self.dx_spin.setValue(self.dx_spin.value() + 1)) + self.apply_all_button.clicked.connect(self.applyAllRequested) + + def set_channel(self, channel: str, data: Dict[str, 
Tuple[int, int]]) -> None: + self._channel = channel + self.label.setText(f"Channel: {channel} relative to DAPI") + if channel not in self._offsets: + self._offsets[channel] = (0, 0) + if data and channel in data: + self._offsets[channel] = tuple(data[channel]) + self.dy_spin.blockSignals(True) + self.dx_spin.blockSignals(True) + dy, dx = self._offsets[channel] + self.dy_spin.setValue(int(dy)) + self.dx_spin.setValue(int(dx)) + self.dy_spin.blockSignals(False) + self.dx_spin.blockSignals(False) + + def export_offsets(self) -> Dict[str, Tuple[int, int]]: + return {ch: (int(dy), int(dx)) for ch, (dy, dx) in self._offsets.items()} + + def _emit_change(self) -> None: + offsets = (self.dy_spin.value(), self.dx_spin.value()) + self._offsets[self._channel] = offsets + self.translationChanged.emit(self._channel, *offsets) + + +def apply_translation(image: np.ndarray, dy: int, dx: int) -> np.ndarray: + return np.roll(np.roll(image, int(dy), axis=0), int(dx), axis=1) diff --git a/viewer/batch.py b/viewer/batch.py new file mode 100644 index 00000000000..69caf9eef8f --- /dev/null +++ b/viewer/batch.py @@ -0,0 +1,76 @@ +"""Batch processing helpers.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, List + +from PyQt5 import QtCore, QtWidgets + + +@dataclass +class BatchRequest: + root: str + steps: Dict[str, bool] + + +class BatchPanel(QtWidgets.QWidget): + runRequested = QtCore.pyqtSignal(object) + + def __init__(self, parent: QtWidgets.QWidget | None = None) -> None: + super().__init__(parent) + self._build_ui() + + def _build_ui(self) -> None: + layout = QtWidgets.QVBoxLayout(self) + + self.path_edit = QtWidgets.QLineEdit() + self.choose_button = QtWidgets.QPushButton("Browse…") + path_layout = QtWidgets.QHBoxLayout() + path_layout.addWidget(self.path_edit) + path_layout.addWidget(self.choose_button) + layout.addLayout(path_layout) + + self.preprocess_check = QtWidgets.QCheckBox("Preprocessing") + self.align_check = QtWidgets.QCheckBox("Alignment") + self.segment_check = QtWidgets.QCheckBox("Segmentation") + for chk in (self.preprocess_check, self.align_check, self.segment_check): + chk.setChecked(True) + layout.addWidget(chk) + + self.run_button = QtWidgets.QPushButton("Run batch") + layout.addWidget(self.run_button) + + self.progress = QtWidgets.QProgressBar() + self.progress.setRange(0, 100) + layout.addWidget(self.progress) + + self.log = QtWidgets.QPlainTextEdit() + self.log.setReadOnly(True) + layout.addWidget(self.log) + + layout.addStretch() + + self.choose_button.clicked.connect(self._choose_directory) + self.run_button.clicked.connect(self._emit_request) + + def _choose_directory(self) -> None: + path = QtWidgets.QFileDialog.getExistingDirectory(self, "Select dataset root") + if path: + self.path_edit.setText(path) + + def _emit_request(self) -> None: + request = BatchRequest( + root=self.path_edit.text(), + steps={ + "preprocessing": self.preprocess_check.isChecked(), + "alignment": self.align_check.isChecked(), + "segmentation": self.segment_check.isChecked(), + }, + ) + self.runRequested.emit(request) + + def log_message(self, message: str) -> None: + self.log.appendPlainText(message) + + def update_progress(self, percent: int) -> None: + self.progress.setValue(percent) diff --git a/viewer/config.py b/viewer/config.py new file mode 100644 index 00000000000..a61045e5be3 --- /dev/null +++ b/viewer/config.py @@ -0,0 +1,49 @@ +"""JSON based persistence for viewer parameters.""" +from __future__ import annotations + +import json +from 
dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict + + +CONFIG_FILENAME = "config.json" +ALIGNMENT_FILENAME = "alignment.json" + + +@dataclass +class Config: + preprocessing: Dict[str, Any] + alignment: Dict[str, Any] + segmentation: Dict[str, Any] + + +def load_config(slide_path: Path) -> Config: + config_path = slide_path / CONFIG_FILENAME + alignment_path = slide_path / ALIGNMENT_FILENAME + + def _read(path: Path) -> Dict[str, Any]: + if path.exists(): + try: + return json.loads(path.read_text()) + except json.JSONDecodeError: + return {} + return {} + + data = _read(config_path) + alignment = _read(alignment_path) + return Config( + preprocessing=data.get("preprocessing", {}), + alignment=alignment, + segmentation=data.get("segmentation", {}), + ) + + +def save_config(slide_path: Path, config: Config) -> None: + config_path = slide_path / CONFIG_FILENAME + alignment_path = slide_path / ALIGNMENT_FILENAME + config_path.write_text(json.dumps({ + "preprocessing": config.preprocessing, + "segmentation": config.segmentation, + }, indent=2)) + alignment_path.write_text(json.dumps(config.alignment, indent=2)) diff --git a/viewer/main.py b/viewer/main.py new file mode 100644 index 00000000000..04baa6f2a71 --- /dev/null +++ b/viewer/main.py @@ -0,0 +1,545 @@ +"""Entry point for the Cytogenetic Image Multi-Channel Viewer.""" +from __future__ import annotations + +import functools +from pathlib import Path +from typing import Dict, List, Optional + +import numpy as np +from PyQt5 import QtCore, QtGui, QtWidgets + +from . import alignment, batch, config, preprocessing, segmentation +from .utils import data +from skimage import morphology + + +def numpy_to_qimage(arr: np.ndarray) -> QtGui.QImage: + """Convert a float image (0-1) or uint8/uint16 to a QImage.""" + if arr.ndim == 2: + if arr.dtype != np.uint8: + norm = np.clip(arr, 0, 1) + arr8 = (norm * 255).astype(np.uint8) + else: + arr8 = arr + arr8 = np.ascontiguousarray(arr8) + h, w = arr8.shape + return QtGui.QImage(arr8.data, w, h, w, QtGui.QImage.Format_Grayscale8).copy() + if arr.ndim == 3 and arr.shape[2] == 3: + if arr.dtype != np.uint8: + norm = np.clip(arr, 0, 1) + arr8 = (norm * 255).astype(np.uint8) + else: + arr8 = arr + arr8 = np.ascontiguousarray(arr8) + h, w, _ = arr8.shape + return QtGui.QImage(arr8.data, w, h, w * 3, QtGui.QImage.Format_RGB888).copy() + raise ValueError("Unsupported array shape for QImage") + + +class ImageCanvas(QtWidgets.QGraphicsView): + maskEdited = QtCore.pyqtSignal(object) + + def __init__(self, parent: QtWidgets.QWidget | None = None) -> None: + super().__init__(parent) + self.setScene(QtWidgets.QGraphicsScene(self)) + self._pixmap_item = QtWidgets.QGraphicsPixmapItem() + self.scene().addItem(self._pixmap_item) + self._mask_item = QtWidgets.QGraphicsPixmapItem() + self._mask_item.setOpacity(0.5) + self.scene().addItem(self._mask_item) + self._image_array: Optional[np.ndarray] = None + self._mask: Optional[np.ndarray] = None + self._display_mode = "overlay" + self._brush_radius = 8 + self._tool = "brush" + self._editing = False + self._undo_stack: List[np.ndarray] = [] + self._redo_stack: List[np.ndarray] = [] + self.setRenderHints(QtGui.QPainter.Antialiasing | QtGui.QPainter.SmoothPixmapTransform) + + # ------------------------------------------------------------------ + def set_image(self, qimage: QtGui.QImage, array: np.ndarray) -> None: + self._image_array = array + self._pixmap_item.setPixmap(QtGui.QPixmap.fromImage(qimage)) + 
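        # Match the scene rect to the new pixmap and rescale the view so the image fits.
+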
self.scene().setSceneRect(self._pixmap_item.boundingRect()) + self.fitInView(self._pixmap_item, QtCore.Qt.KeepAspectRatio) + + def set_mask(self, mask: Optional[np.ndarray]) -> None: + if mask is None: + self._mask = None + self._mask_item.setPixmap(QtGui.QPixmap()) + self._undo_stack.clear() + self._redo_stack.clear() + return + self._mask = mask.astype(bool) + self._undo_stack.clear() + self._redo_stack.clear() + self._update_mask_pixmap() + + def set_display_mode(self, mode: str) -> None: + self._display_mode = mode + self._update_mask_pixmap() + + def set_brush(self, radius: int, tool: str) -> None: + self._brush_radius = max(1, int(radius)) + self._tool = tool + + def set_editing_enabled(self, enabled: bool) -> None: + self._editing = enabled + + def undo(self) -> None: + if not self._undo_stack: + return + current = self._mask.copy() + self._redo_stack.append(current) + prev = self._undo_stack.pop() + self._mask = prev + self._update_mask_pixmap() + self.maskEdited.emit(self._mask.copy()) + + def redo(self) -> None: + if not self._redo_stack: + return + self._undo_stack.append(self._mask.copy()) + next_mask = self._redo_stack.pop() + self._mask = next_mask + self._update_mask_pixmap() + self.maskEdited.emit(self._mask.copy()) + + # ------------------------------------------------------------------ + def mousePressEvent(self, event: QtGui.QMouseEvent) -> None: + if self._editing and self._mask is not None and event.buttons() & QtCore.Qt.LeftButton: + self._push_undo() + self._apply_brush(event.pos()) + super().mousePressEvent(event) + + def mouseMoveEvent(self, event: QtGui.QMouseEvent) -> None: + if self._editing and self._mask is not None and event.buttons() & QtCore.Qt.LeftButton: + self._apply_brush(event.pos()) + super().mouseMoveEvent(event) + + def _push_undo(self) -> None: + if self._mask is not None: + self._undo_stack.append(self._mask.copy()) + self._redo_stack.clear() + + def _apply_brush(self, pos: QtCore.QPoint) -> None: + if self._mask is None: + return + scene_pos = self.mapToScene(pos) + x = int(scene_pos.x()) + y = int(scene_pos.y()) + if not (0 <= x < self._mask.shape[1] and 0 <= y < self._mask.shape[0]): + return + yy, xx = np.ogrid[: self._mask.shape[0], : self._mask.shape[1]] + circle = (xx - x) ** 2 + (yy - y) ** 2 <= self._brush_radius ** 2 + if self._tool == "brush": + self._mask[circle] = True + else: + self._mask[circle] = False + self._update_mask_pixmap() + self.maskEdited.emit(self._mask.copy()) + + def _update_mask_pixmap(self) -> None: + if self._mask is None: + self._mask_item.setPixmap(QtGui.QPixmap()) + return + mask_bool = self._mask.astype(bool) + mask = mask_bool.astype(np.uint8) + if self._display_mode == "contour": + eroded = morphology.binary_erosion(mask_bool) + edges = mask_bool & ~eroded + rgb = np.zeros((*mask.shape, 3), dtype=np.uint8) + rgb[..., 0] = edges * 255 + rgb[..., 1] = edges * 255 + rgb[..., 2] = 0 + qimage = QtGui.QImage(rgb.data, rgb.shape[1], rgb.shape[0], QtGui.QImage.Format_RGB888) + self._mask_item.setPixmap(QtGui.QPixmap.fromImage(qimage)) + return + color = np.zeros((*mask.shape, 4), dtype=np.uint8) + if self._display_mode == "overlay": + color[..., 0] = 255 + color[..., 1] = 128 + color[..., 3] = mask * 120 + else: # filled + color[..., 0] = mask * 255 + color[..., 3] = mask * 255 + qimage = QtGui.QImage(color.data, color.shape[1], color.shape[0], QtGui.QImage.Format_RGBA8888) + self._mask_item.setPixmap(QtGui.QPixmap.fromImage(qimage)) + + +class MainWindow(QtWidgets.QMainWindow): + def __init__(self, dataset: 
data.Dataset) -> None: + super().__init__() + self.dataset = dataset + self.current_slide: Optional[data.Slide] = None + self.current_metaphase: Optional[data.Metaphase] = None + self.current_channel: str = "dapi" + self.config: Optional[config.Config] = None + self.raw_images: Dict[str, np.ndarray] = {} + self.preprocessed: Dict[str, np.ndarray] = {} + self.alignment_offsets: Dict[str, tuple] = {} + self.segmentation_masks: Dict[str, np.ndarray] = {} + self.segmentation_mask: Optional[np.ndarray] = None + self._build_ui() + self._populate_slides() + + # ------------------------------------------------------------------ + def _build_ui(self) -> None: + self.setWindowTitle("Cytogenetic Image Viewer") + central = QtWidgets.QWidget() + self.setCentralWidget(central) + + layout = QtWidgets.QHBoxLayout(central) + + left_container = QtWidgets.QWidget() + left_layout = QtWidgets.QVBoxLayout(left_container) + + navigation_layout = QtWidgets.QHBoxLayout() + self.slide_combo = QtWidgets.QComboBox() + self.metaphase_combo = QtWidgets.QComboBox() + navigation_layout.addWidget(QtWidgets.QLabel("Slide")) + navigation_layout.addWidget(self.slide_combo) + navigation_layout.addWidget(QtWidgets.QLabel("Metaphase")) + navigation_layout.addWidget(self.metaphase_combo) + left_layout.addLayout(navigation_layout) + + self.channel_buttons_layout = QtWidgets.QHBoxLayout() + left_layout.addLayout(self.channel_buttons_layout) + + self.canvas = ImageCanvas() + self.canvas.set_editing_enabled(False) + left_layout.addWidget(self.canvas, stretch=1) + + layout.addWidget(left_container, stretch=3) + + self.tabs = QtWidgets.QTabWidget() + self.preprocessing_panel = preprocessing.PreprocessingPanel() + self.alignment_panel = alignment.AlignmentPanel() + self.segmentation_panel = segmentation.SegmentationPanel() + self.batch_panel = batch.BatchPanel() + self.tabs.addTab(self.preprocessing_panel, "Preprocessing") + self.tabs.addTab(self.alignment_panel, "Alignment") + self.tabs.addTab(self.segmentation_panel, "Segmentation") + self.tabs.addTab(self.batch_panel, "Batch") + layout.addWidget(self.tabs, stretch=1) + + # Signals + self.slide_combo.currentTextChanged.connect(self._on_slide_changed) + self.metaphase_combo.currentTextChanged.connect(self._on_metaphase_changed) + self.preprocessing_panel.settingsChanged.connect(self._on_preprocessing_settings) + self.preprocessing_panel.applyRequested.connect(self._apply_preprocessing) + self.preprocessing_panel.batchRequested.connect(self._run_preprocessing_batch) + self.preprocessing_panel.saveRequested.connect(self._save_preprocessed) + self.alignment_panel.translationChanged.connect(self._on_alignment_changed) + self.alignment_panel.applyAllRequested.connect(self._apply_alignment_all) + self.segmentation_panel.generateRequested.connect(self._on_generate_segmentation) + self.segmentation_panel.saveRequested.connect(self._save_segmentation) + self.segmentation_panel.displayModeChanged.connect(self.canvas.set_display_mode) + self.segmentation_panel.brushChanged.connect(self._on_brush_changed) + self.segmentation_panel.undoRequested.connect(self.canvas.undo) + self.segmentation_panel.redoRequested.connect(self.canvas.redo) + self.segmentation_panel.filterRequested.connect(self._apply_segmentation_filters) + self.canvas.maskEdited.connect(self._on_mask_edited) + self.batch_panel.runRequested.connect(self._on_batch_requested) + + # ------------------------------------------------------------------ + def _populate_slides(self) -> None: + self.slide_combo.blockSignals(True) + 
self.slide_combo.clear() + for slide_name in self.dataset.available_slides(): + self.slide_combo.addItem(slide_name) + self.slide_combo.blockSignals(False) + if self.slide_combo.count() > 0: + self.slide_combo.setCurrentIndex(0) + self._on_slide_changed(self.slide_combo.currentText()) + + def _on_slide_changed(self, name: str) -> None: + if not name: + return + if self.current_slide is not None: + self._persist_current_config() + self.current_slide = self.dataset.slides[name] + self.config = config.load_config(self.current_slide.path) + self.alignment_offsets = {} + for k, v in self.config.alignment.items(): + if isinstance(v, (list, tuple)) and len(v) >= 2: + self.alignment_offsets[k] = (int(v[0]), int(v[1])) + self.preprocessing_panel.set_channel(self.current_channel, self.config.preprocessing) + self.alignment_panel.set_channel(self.current_channel, self.alignment_offsets) + self.segmentation_panel.set_channel(self.current_channel, self.config.segmentation) + self._populate_metaphases() + + def _populate_metaphases(self) -> None: + self.metaphase_combo.blockSignals(True) + self.metaphase_combo.clear() + if not self.current_slide: + return + for meta in self.current_slide.available_metaphases(): + self.metaphase_combo.addItem(meta) + self.metaphase_combo.blockSignals(False) + if self.metaphase_combo.count() > 0: + self.metaphase_combo.setCurrentIndex(0) + self._on_metaphase_changed(self.metaphase_combo.currentText()) + + def _on_metaphase_changed(self, name: str) -> None: + if not name or not self.current_slide: + return + self.current_metaphase = self.current_slide.metaphases[name] + self.raw_images = {} + self.preprocessed = {} + self.segmentation_masks = {} + self.segmentation_mask = None + for channel, files in self.current_metaphase.channels.items(): + image = data.load_image(files[0]) + self.raw_images[channel] = image + self._setup_channel_buttons() + if self.current_channel not in self.raw_images: + self.current_channel = self.current_metaphase.available_channels()[0] + self.preprocessing_panel.set_channel(self.current_channel, self.config.preprocessing) + self.alignment_panel.set_channel(self.current_channel, self.alignment_offsets) + self.segmentation_panel.set_channel(self.current_channel, self.config.segmentation) + self._refresh_view() + + def _setup_channel_buttons(self) -> None: + # Clear layout + while self.channel_buttons_layout.count(): + item = self.channel_buttons_layout.takeAt(0) + widget = item.widget() + if widget is not None: + widget.deleteLater() + if not self.current_metaphase: + return + for channel in self.current_metaphase.available_channels(): + btn = QtWidgets.QToolButton() + btn.setText(channel) + btn.setCheckable(True) + btn.setChecked(channel == self.current_channel) + btn.clicked.connect(functools.partial(self._on_channel_selected, channel)) + self.channel_buttons_layout.addWidget(btn) + + def _on_channel_selected(self, channel: str) -> None: + self.current_channel = channel + self.preprocessing_panel.set_channel(channel, self.config.preprocessing) + self.alignment_panel.set_channel(channel, self.alignment_offsets) + self.segmentation_panel.set_channel(channel, self.config.segmentation) + self.segmentation_mask = self.segmentation_masks.get(channel) + self._refresh_view() + + def _refresh_view(self) -> None: + if not self.current_metaphase: + return + processed = self._get_processed_image(self.current_channel) + settings = self.preprocessing_panel.settings_for(self.current_channel) + norm = processed + if norm.max() > 0: + norm = processed / 
processed.max() + if settings.display_mode == "inverted": + norm = 1.0 - norm + qimage = numpy_to_qimage(norm) + elif settings.display_mode == "false_color": + rgb = preprocessing.map_false_colour(norm, settings) + qimage = numpy_to_qimage(rgb) + else: + qimage = numpy_to_qimage(norm) + self.canvas.set_image(qimage, processed) + mask = self.segmentation_masks.get(self.current_channel) + self.segmentation_mask = mask + self.canvas.set_mask(mask) + self.canvas.set_editing_enabled(mask is not None) + + def _get_processed_image(self, channel: str) -> np.ndarray: + if channel in self.preprocessed: + return self.preprocessed[channel] + raw = self.raw_images[channel] + settings = self.preprocessing_panel.settings_for(channel) + processed = preprocessing.preprocess_image(raw, settings) + offsets = self.alignment_offsets.get(channel, (0, 0)) + processed = alignment.apply_translation(processed, offsets[0], offsets[1]) + self.preprocessed[channel] = processed + return processed + + def _on_preprocessing_settings(self, channel: str, settings: preprocessing.PreprocessingSettings) -> None: + if channel in self.preprocessed: + del self.preprocessed[channel] + self.config.preprocessing[channel] = settings.to_dict() + if channel == self.current_channel: + self._refresh_view() + + def _apply_preprocessing(self, channel: str) -> None: + self.preprocessed.pop(channel, None) + self._refresh_view() + + def _run_preprocessing_batch(self, channel: str) -> None: + if not self.current_slide: + return + metaphases = list(self.current_slide.metaphases.values()) + targets = [(meta, meta.channels.get(channel)) for meta in metaphases] + targets = [(meta, files) for meta, files in targets if files] + if not targets: + QtWidgets.QMessageBox.information(self, "Preprocessing", "No images found for this channel on the current slide.") + return + progress = QtWidgets.QProgressDialog("Preprocessing", "Cancel", 0, len(targets), self) + progress.setWindowTitle("Preprocessing batch") + progress.setWindowModality(QtCore.Qt.WindowModal) + settings = self.preprocessing_panel.settings_for(channel) + from skimage import io + + for idx, (metaphase, files) in enumerate(targets, start=1): + if progress.wasCanceled(): + break + image = data.load_image(files[0]) + result = preprocessing.preprocess_image(image, settings) + out_path = metaphase.path / f"{channel}_preprocessed.tif" + io.imsave(str(out_path), (np.clip(result, 0, 1) * 65535).astype(np.uint16)) + progress.setValue(idx) + progress.setLabelText(f"Saved {out_path.name}") + progress.close() + + def _save_preprocessed(self, channel: str) -> None: + if channel not in self.preprocessed: + self._refresh_view() + arr = self.preprocessed[channel] + path, _ = QtWidgets.QFileDialog.getSaveFileName(self, "Save preprocessed image", f"{channel}.tif") + if path: + from skimage import io + + max_val = float(arr.max()) if arr.size else 1.0 + norm = arr / max_val if max_val > 0 else arr + io.imsave(path, np.clip(norm, 0, 1)) + + def _on_alignment_changed(self, channel: str, dy: int, dx: int) -> None: + self.alignment_offsets[channel] = (dy, dx) + self.config.alignment[channel] = [dy, dx] + if channel in self.preprocessed: + del self.preprocessed[channel] + if channel == self.current_channel: + self._refresh_view() + + def _apply_alignment_all(self) -> None: + if not self.current_metaphase: + return + for channel in self.current_metaphase.available_channels(): + if channel == "dapi": + continue + offset = self.alignment_offsets.get(channel, (0, 0)) + self.config.alignment[channel] = [offset[0], 
offset[1]] + QtWidgets.QMessageBox.information(self, "Alignment", "Offsets saved for all channels in the slide.") + + def _on_generate_segmentation(self, channel: str, settings: segmentation.SegmentationSettings) -> None: + image = self._get_processed_image(channel) + mask = segmentation.adaptive_threshold(image, settings) + self.segmentation_mask = segmentation.apply_filters(mask, settings) + self.config.segmentation[channel] = settings.to_dict() + self.segmentation_masks[channel] = self.segmentation_mask + self.canvas.set_mask(self.segmentation_mask) + self.canvas.set_editing_enabled(True) + + def _apply_segmentation_filters(self) -> None: + if self.segmentation_mask is None: + return + settings = self.segmentation_panel.settings_for(self.current_channel) + self.segmentation_mask = segmentation.apply_filters(self.segmentation_mask, settings) + self.segmentation_masks[self.current_channel] = self.segmentation_mask + self.canvas.set_mask(self.segmentation_mask) + + def _save_segmentation(self, channel: str) -> None: + mask = self.segmentation_masks.get(channel) + if mask is None: + return + path, _ = QtWidgets.QFileDialog.getSaveFileName(self, "Save mask", f"{channel}_mask.tif") + if path: + from skimage import io + + io.imsave(path, (mask > 0).astype(np.uint8) * 255) + + def _on_brush_changed(self, radius: int, tool: str) -> None: + self.canvas.set_brush(radius, tool) + + def _on_mask_edited(self, mask: np.ndarray) -> None: + self.segmentation_mask = mask + self.segmentation_masks[self.current_channel] = mask + + def _persist_current_config(self) -> None: + if not self.current_slide: + return + self.alignment_offsets = self.alignment_panel.export_offsets() + cfg = config.Config( + preprocessing=self.preprocessing_panel.export_settings(), + alignment=self.alignment_offsets, + segmentation=self.segmentation_panel.export_settings(), + ) + config.save_config(self.current_slide.path, cfg) + + def closeEvent(self, event: QtGui.QCloseEvent) -> None: + self._persist_current_config() + super().closeEvent(event) + + def _on_batch_requested(self, request: batch.BatchRequest) -> None: + root = Path(request.root) if request.root else self.current_slide.path + dataset = data.discover_dataset(root) + if dataset is None: + QtWidgets.QMessageBox.warning(self, "Batch", "No dataset found at the selected path.") + return + slides = dataset.available_slides() + if not slides: + QtWidgets.QMessageBox.warning(self, "Batch", "Dataset does not contain any slides.") + return + total = len(slides) + for idx, slide_name in enumerate(slides, start=1): + slide = dataset.slides[slide_name] + slide_cfg = config.load_config(slide.path) + for metaphase in slide.metaphases.values(): + for channel, files in metaphase.channels.items(): + image = data.load_image(files[0]) + if request.steps["preprocessing"]: + settings = preprocessing.PreprocessingSettings.from_dict( + slide_cfg.preprocessing.get(channel, {}) + ) + image = preprocessing.preprocess_image(image, settings) + if request.steps["alignment"]: + dy, dx = slide_cfg.alignment.get(channel, (0, 0)) + image = alignment.apply_translation(image, dy, dx) + if request.steps["segmentation"]: + seg_settings = segmentation.SegmentationSettings.from_dict( + slide_cfg.segmentation.get(channel, {}) + ) + mask = segmentation.adaptive_threshold(image, seg_settings) + mask = segmentation.apply_filters(mask, seg_settings) + out_path = metaphase.path / f"{channel}_mask.png" + from skimage import io + + io.imsave(str(out_path), mask.astype(np.uint8) * 255) + percent = int(idx / total * 
100) + self.batch_panel.update_progress(percent) + self.batch_panel.log_message(f"Processed slide {slide_name}") + QtWidgets.QMessageBox.information(self, "Batch", "Batch processing finished.") + + +def main() -> None: + import argparse + + parser = argparse.ArgumentParser(description="Cytogenetic image viewer") + parser.add_argument("root", nargs="?", default=None, help="Dataset root directory") + args = parser.parse_args() + + default_paths = [ + Path.cwd() / "dataset", + Path.cwd() / "Raw images" / "jpp21", + Path.cwd() / "jpp21_downloaded_data", + Path.cwd(), + ] + if args.root: + default_paths.insert(0, Path(args.root)) + + dataset = data.ensure_dataset(default_paths) + + app = QtWidgets.QApplication([]) + window = MainWindow(dataset) + window.resize(1400, 900) + window.show() + app.exec_() + + +if __name__ == "__main__": + main() diff --git a/viewer/preprocessing.py b/viewer/preprocessing.py new file mode 100644 index 00000000000..47e76d29d03 --- /dev/null +++ b/viewer/preprocessing.py @@ -0,0 +1,243 @@ +"""Preprocessing controls and algorithms.""" +from __future__ import annotations + +from dataclasses import asdict, dataclass +from typing import Dict, Tuple + +import numpy as np +from PyQt5 import QtCore, QtGui, QtWidgets +from matplotlib import colors as mpl_colors +from skimage import exposure, morphology + + +@dataclass +class PreprocessingSettings: + """Configuration for a single channel.""" + + radius: int = 10 + background_method: str = "median" + display_mode: str = "grayscale" # grayscale, inverted, false_color + false_color_min: Tuple[float, float, float] = (0.6, 0.4, 0.9) # HSV + false_color_max: Tuple[float, float, float] = (0.0, 1.0, 1.0) + clip_low: float = 2.0 + clip_high: float = 98.0 + + def to_dict(self) -> Dict[str, float]: + data = asdict(self) + data["false_color_min"] = list(self.false_color_min) + data["false_color_max"] = list(self.false_color_max) + return data + + @classmethod + def from_dict(cls, data: Dict[str, float]) -> "PreprocessingSettings": + if not data: + return cls() + values = dict(data) + if "false_color_min" in values: + values["false_color_min"] = tuple(values["false_color_min"]) + if "false_color_max" in values: + values["false_color_max"] = tuple(values["false_color_max"]) + return cls(**values) + + +class PreprocessingPanel(QtWidgets.QWidget): + """Qt widget exposing preprocessing controls.""" + + settingsChanged = QtCore.pyqtSignal(str, object) + applyRequested = QtCore.pyqtSignal(str) + batchRequested = QtCore.pyqtSignal(str) + saveRequested = QtCore.pyqtSignal(str) + + def __init__(self, parent: QtWidgets.QWidget | None = None) -> None: + super().__init__(parent) + self._channel = "dapi" + self._settings: Dict[str, PreprocessingSettings] = { + "dapi": PreprocessingSettings(), + "cy3": PreprocessingSettings(false_color_min=(0.15, 0.9, 0.95), false_color_max=(0.07, 1.0, 1.0)), + "cy5": PreprocessingSettings(false_color_min=(0.0, 1.0, 0.6), false_color_max=(0.0, 1.0, 1.0)), + "fitc": PreprocessingSettings(false_color_min=(0.33, 0.8, 0.7), false_color_max=(0.25, 1.0, 1.0)), + } + self._build_ui() + self._update_ui() + + # ------------------------------------------------------------------ + # UI helpers + def _build_ui(self) -> None: + layout = QtWidgets.QVBoxLayout(self) + + self.channel_label = QtWidgets.QLabel("Channel: dapi") + layout.addWidget(self.channel_label) + + self.radius_slider = QtWidgets.QSlider(QtCore.Qt.Horizontal) + self.radius_slider.setRange(0, 50) + self.radius_slider.setValue(10) + 
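        # Structuring-element radius in pixels; 0 disables the white top-hat filter.
+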
layout.addWidget(QtWidgets.QLabel("White top-hat radius")) + layout.addWidget(self.radius_slider) + + self.background_combo = QtWidgets.QComboBox() + self.background_combo.addItems(["median", "mode"]) + layout.addWidget(QtWidgets.QLabel("Background subtraction")) + layout.addWidget(self.background_combo) + + self.display_combo = QtWidgets.QComboBox() + self.display_combo.addItems(["grayscale", "inverted", "false_color"]) + layout.addWidget(QtWidgets.QLabel("Display mode")) + layout.addWidget(self.display_combo) + + self.clip_low_spin = QtWidgets.QDoubleSpinBox() + self.clip_low_spin.setRange(0.0, 50.0) + self.clip_low_spin.setSuffix(" %") + self.clip_low_spin.setValue(2.0) + self.clip_high_spin = QtWidgets.QDoubleSpinBox() + self.clip_high_spin.setRange(50.0, 100.0) + self.clip_high_spin.setSuffix(" %") + self.clip_high_spin.setValue(98.0) + clip_layout = QtWidgets.QFormLayout() + clip_layout.addRow("Clip low", self.clip_low_spin) + clip_layout.addRow("Clip high", self.clip_high_spin) + layout.addLayout(clip_layout) + + color_layout = QtWidgets.QHBoxLayout() + self.min_color_button = QtWidgets.QPushButton("Min colour") + self.max_color_button = QtWidgets.QPushButton("Max colour") + color_layout.addWidget(self.min_color_button) + color_layout.addWidget(self.max_color_button) + layout.addLayout(color_layout) + + button_layout = QtWidgets.QHBoxLayout() + self.apply_button = QtWidgets.QPushButton("Apply") + self.batch_button = QtWidgets.QPushButton("Batch") + self.save_button = QtWidgets.QPushButton("Save result") + button_layout.addWidget(self.apply_button) + button_layout.addWidget(self.batch_button) + button_layout.addWidget(self.save_button) + layout.addLayout(button_layout) + + layout.addStretch() + + # Signals + self.radius_slider.valueChanged.connect(self._emit_change) + self.background_combo.currentTextChanged.connect(self._emit_change) + self.display_combo.currentTextChanged.connect(self._emit_change) + self.clip_low_spin.valueChanged.connect(self._emit_change) + self.clip_high_spin.valueChanged.connect(self._emit_change) + self.min_color_button.clicked.connect(lambda: self._choose_color(True)) + self.max_color_button.clicked.connect(lambda: self._choose_color(False)) + self.apply_button.clicked.connect(lambda: self.applyRequested.emit(self._channel)) + self.batch_button.clicked.connect(lambda: self.batchRequested.emit(self._channel)) + self.save_button.clicked.connect(lambda: self.saveRequested.emit(self._channel)) + + def _choose_color(self, is_min: bool) -> None: + settings = self._settings[self._channel] + hsv = settings.false_color_min if is_min else settings.false_color_max + qcolor = QtGui.QColor() + qcolor.setHsvF(*hsv) + chosen = QtWidgets.QColorDialog.getColor(qcolor, self) + if chosen.isValid(): + hsv = (chosen.hueF(), chosen.saturationF(), chosen.valueF()) + if is_min: + settings.false_color_min = hsv + else: + settings.false_color_max = hsv + self._emit_change() + + def _update_ui(self) -> None: + settings = self._settings[self._channel] + self.channel_label.setText(f"Channel: {self._channel}") + self.radius_slider.blockSignals(True) + self.background_combo.blockSignals(True) + self.display_combo.blockSignals(True) + self.clip_low_spin.blockSignals(True) + self.clip_high_spin.blockSignals(True) + self.radius_slider.setValue(settings.radius) + self.background_combo.setCurrentText(settings.background_method) + self.display_combo.setCurrentText(settings.display_mode) + self.clip_low_spin.setValue(settings.clip_low) + self.clip_high_spin.setValue(settings.clip_high) + 
self.radius_slider.blockSignals(False) + self.background_combo.blockSignals(False) + self.display_combo.blockSignals(False) + self.clip_low_spin.blockSignals(False) + self.clip_high_spin.blockSignals(False) + + def set_channel(self, name: str, data: Dict[str, Dict]) -> None: + self._channel = name + if name not in self._settings: + self._settings[name] = PreprocessingSettings() + if data and name in data: + self._settings[name] = PreprocessingSettings.from_dict(data[name]) + self._update_ui() + + def export_settings(self) -> Dict[str, Dict]: + self._emit_change() + return {ch: cfg.to_dict() for ch, cfg in self._settings.items()} + + def settings_for(self, channel: str) -> PreprocessingSettings: + return self._settings.setdefault(channel, PreprocessingSettings()) + + def _emit_change(self) -> None: + settings = self._settings[self._channel] + settings.radius = self.radius_slider.value() + settings.background_method = self.background_combo.currentText() + settings.display_mode = self.display_combo.currentText() + settings.clip_low = self.clip_low_spin.value() + settings.clip_high = self.clip_high_spin.value() + if settings.clip_low >= settings.clip_high: + settings.clip_high = min(100.0, settings.clip_low + 1.0) + self.clip_high_spin.blockSignals(True) + self.clip_high_spin.setValue(settings.clip_high) + self.clip_high_spin.blockSignals(False) + self.settingsChanged.emit(self._channel, settings) + + +# --------------------------------------------------------------------------- +# Processing functions + + +def apply_white_tophat(image: np.ndarray, radius: int) -> np.ndarray: + if radius <= 0: + return image + selem = morphology.disk(radius) + return morphology.white_tophat(image, selem) + + +def subtract_background(image: np.ndarray, method: str) -> np.ndarray: + if method == "mode": + hist, bin_edges = np.histogram(image, bins=512) + mode_index = int(np.argmax(hist)) + background = (bin_edges[mode_index] + bin_edges[mode_index + 1]) / 2 + else: + background = float(np.median(image)) + corrected = image - background + corrected[corrected < 0] = 0 + return corrected + + +def stretch_contrast(image: np.ndarray, low: float, high: float) -> np.ndarray: + low_val, high_val = np.percentile(image, (low, high)) + if np.isclose(low_val, high_val): + return np.zeros_like(image) + return exposure.rescale_intensity(image, in_range=(low_val, high_val)) + + +def preprocess_image(image: np.ndarray, settings: PreprocessingSettings) -> np.ndarray: + arr = image.astype(np.float32) + arr = apply_white_tophat(arr, settings.radius) + arr = subtract_background(arr, settings.background_method) + arr = stretch_contrast(arr, settings.clip_low, settings.clip_high) + max_val = float(arr.max()) if arr.size else 0.0 + if max_val > 0: + arr = arr / max_val + return arr + + +def map_false_colour(image: np.ndarray, settings: PreprocessingSettings) -> np.ndarray: + arr = np.clip(image, 0, 1) + h_low, s_low, v_low = settings.false_color_min + h_high, s_high, v_high = settings.false_color_max + h = h_low + (h_high - h_low) * arr + s = s_low + (s_high - s_low) * arr + v = v_low + (v_high - v_low) * arr + hsv = np.stack([h, s, v], axis=-1) + rgb = mpl_colors.hsv_to_rgb(hsv) + return rgb diff --git a/viewer/segmentation.py b/viewer/segmentation.py new file mode 100644 index 00000000000..a99fd5c6444 --- /dev/null +++ b/viewer/segmentation.py @@ -0,0 +1,228 @@ +"""Segmentation controls and utilities.""" +from __future__ import annotations + +from dataclasses import asdict, dataclass +from typing import Dict, Optional + +import cv2 
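+# OpenCV is only needed for the adaptive thresholding step (cv2.normalize and
+# cv2.adaptiveThreshold in adaptive_threshold() below); mask clean-up relies on
+# scikit-image's measure/morphology helpers.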
+import numpy as np +from PyQt5 import QtCore, QtWidgets +from skimage import measure, morphology + + +@dataclass +class SegmentationSettings: + block_size: int = 35 + offset: int = -10 + display_mode: str = "overlay" # overlay, contour, filled + brush_radius: int = 8 + min_size: int = 50 + max_size: int = 10000 + borderkill: bool = False + + def to_dict(self) -> Dict: + return asdict(self) + + @classmethod + def from_dict(cls, data: Optional[Dict]) -> "SegmentationSettings": + if not data: + return cls() + return cls(**data) + + +class SegmentationPanel(QtWidgets.QWidget): + generateRequested = QtCore.pyqtSignal(str, object) + saveRequested = QtCore.pyqtSignal(str) + displayModeChanged = QtCore.pyqtSignal(str) + brushChanged = QtCore.pyqtSignal(int, str) + undoRequested = QtCore.pyqtSignal() + redoRequested = QtCore.pyqtSignal() + filterRequested = QtCore.pyqtSignal() + + def __init__(self, parent: QtWidgets.QWidget | None = None) -> None: + super().__init__(parent) + self._channel = "dapi" + self._per_channel: Dict[str, SegmentationSettings] = { + "dapi": SegmentationSettings(), + } + self._settings = self._per_channel[self._channel] + self._build_ui() + + def _build_ui(self) -> None: + layout = QtWidgets.QVBoxLayout(self) + + self.channel_label = QtWidgets.QLabel("Channel: dapi") + layout.addWidget(self.channel_label) + + self.block_spin = QtWidgets.QSpinBox() + self.block_spin.setRange(3, 255) + self.block_spin.setSingleStep(2) + self.block_spin.setValue(self._settings.block_size) + layout.addWidget(QtWidgets.QLabel("Adaptive block size")) + layout.addWidget(self.block_spin) + + self.offset_spin = QtWidgets.QSpinBox() + self.offset_spin.setRange(-50, 50) + self.offset_spin.setValue(self._settings.offset) + layout.addWidget(QtWidgets.QLabel("Threshold offset")) + layout.addWidget(self.offset_spin) + + self.display_combo = QtWidgets.QComboBox() + self.display_combo.addItems(["overlay", "contour", "filled"]) + layout.addWidget(QtWidgets.QLabel("Display mode")) + layout.addWidget(self.display_combo) + + self.brush_combo = QtWidgets.QComboBox() + self.brush_combo.addItems(["brush", "eraser"]) + self.brush_radius = QtWidgets.QSpinBox() + self.brush_radius.setRange(1, 50) + self.brush_radius.setValue(self._settings.brush_radius) + brush_layout = QtWidgets.QFormLayout() + brush_layout.addRow("Tool", self.brush_combo) + brush_layout.addRow("Radius", self.brush_radius) + layout.addLayout(brush_layout) + + self.min_size_spin = QtWidgets.QSpinBox() + self.min_size_spin.setRange(0, 50000) + self.min_size_spin.setValue(self._settings.min_size) + self.max_size_spin = QtWidgets.QSpinBox() + self.max_size_spin.setRange(0, 100000) + self.max_size_spin.setValue(self._settings.max_size) + self.borderkill_check = QtWidgets.QCheckBox("Remove border-touching objects") + self.borderkill_check.setChecked(self._settings.borderkill) + size_layout = QtWidgets.QFormLayout() + size_layout.addRow("Min size", self.min_size_spin) + size_layout.addRow("Max size", self.max_size_spin) + layout.addLayout(size_layout) + layout.addWidget(self.borderkill_check) + + button_layout = QtWidgets.QHBoxLayout() + self.generate_button = QtWidgets.QPushButton("Preview") + self.filter_button = QtWidgets.QPushButton("Filter") + self.save_button = QtWidgets.QPushButton("Save mask") + button_layout.addWidget(self.generate_button) + button_layout.addWidget(self.filter_button) + button_layout.addWidget(self.save_button) + layout.addLayout(button_layout) + + undo_layout = QtWidgets.QHBoxLayout() + self.undo_button = 
QtWidgets.QPushButton("Undo") + self.redo_button = QtWidgets.QPushButton("Redo") + undo_layout.addWidget(self.undo_button) + undo_layout.addWidget(self.redo_button) + layout.addLayout(undo_layout) + + layout.addStretch() + + # Connections + self.generate_button.clicked.connect(self._emit_generate) + self.filter_button.clicked.connect(self._emit_filter) + self.save_button.clicked.connect(lambda: self.saveRequested.emit(self._channel)) + self.display_combo.currentTextChanged.connect(self.displayModeChanged) + self.brush_radius.valueChanged.connect(self._emit_brush) + self.brush_combo.currentTextChanged.connect(self._emit_brush) + self.undo_button.clicked.connect(self.undoRequested) + self.redo_button.clicked.connect(self.redoRequested) + + def _emit_generate(self) -> None: + self._sync_settings() + self.generateRequested.emit(self._channel, self._settings) + + def _emit_brush(self) -> None: + self._sync_settings() + self.brushChanged.emit(self._settings.brush_radius, self.brush_combo.currentText()) + + def _emit_filter(self) -> None: + self._sync_settings() + self.filterRequested.emit() + + def _sync_settings(self) -> None: + self._settings.block_size = int(self.block_spin.value()) | 1 + self._settings.offset = int(self.offset_spin.value()) + self._settings.display_mode = self.display_combo.currentText() + self._settings.brush_radius = int(self.brush_radius.value()) + self._settings.min_size = int(self.min_size_spin.value()) + self._settings.max_size = int(self.max_size_spin.value()) + self._settings.borderkill = self.borderkill_check.isChecked() + self._per_channel[self._channel] = self._settings + + def set_channel(self, name: str, data: Optional[Dict]) -> None: + self._channel = name + self.channel_label.setText(f"Channel: {name}") + if name not in self._per_channel: + self._per_channel[name] = SegmentationSettings() + if data and name in data: + self._per_channel[name] = SegmentationSettings.from_dict(data[name]) + self._settings = self._per_channel[name] + self.block_spin.setValue(self._settings.block_size) + self.offset_spin.setValue(self._settings.offset) + self.display_combo.setCurrentText(self._settings.display_mode) + self.brush_radius.setValue(self._settings.brush_radius) + self.min_size_spin.setValue(self._settings.min_size) + self.max_size_spin.setValue(self._settings.max_size) + self.borderkill_check.setChecked(self._settings.borderkill) + + def export_settings(self) -> Dict[str, Dict]: + self._sync_settings() + return {name: settings.to_dict() for name, settings in self._per_channel.items()} + + def settings_for(self, channel: str) -> SegmentationSettings: + if channel not in self._per_channel: + self._per_channel[channel] = SegmentationSettings() + if channel == self._channel: + self._sync_settings() + return self._per_channel[channel] + + +# --------------------------------------------------------------------------- +# Algorithms + + +def adaptive_threshold(image: np.ndarray, settings: SegmentationSettings) -> np.ndarray: + arr = image.astype(np.uint8) + if arr.max() > 255: + arr = cv2.normalize(arr, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8) + block_size = max(3, settings.block_size | 1) + thresh = cv2.adaptiveThreshold( + arr, + 255, + cv2.ADAPTIVE_THRESH_MEAN_C, + cv2.THRESH_BINARY, + block_size, + settings.offset, + ) + mask = (thresh > 0).astype(np.uint8) + return mask + + +def apply_filters(mask: np.ndarray, settings: SegmentationSettings) -> np.ndarray: + filtered = mask.astype(bool) + if settings.min_size > 0: + filtered = 
morphology.remove_small_objects(filtered, min_size=settings.min_size) + if settings.max_size > 0: + labeled = measure.label(filtered) + props = measure.regionprops(labeled) + allowed = np.zeros_like(filtered) + for prop in props: + if prop.area <= settings.max_size: + allowed[labeled == prop.label] = True + filtered = allowed + if settings.borderkill: + filtered &= ~_touches_border(filtered) + return filtered.astype(np.uint8) + + +def _touches_border(mask: np.ndarray) -> np.ndarray: + border = np.zeros_like(mask, dtype=bool) + border[0, :] = True + border[-1, :] = True + border[:, 0] = True + border[:, -1] = True + border_pixels = mask & border + labeled = measure.label(mask) + kill = np.zeros_like(mask, dtype=bool) + for label in np.unique(labeled[border_pixels]): + if label == 0: + continue + kill[labeled == label] = True + return kill diff --git a/viewer/utils/__init__.py b/viewer/utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/viewer/utils/data.py b/viewer/utils/data.py new file mode 100644 index 00000000000..3f862f7fce0 --- /dev/null +++ b/viewer/utils/data.py @@ -0,0 +1,111 @@ +"""Utility helpers for discovering cytogenetic image datasets.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, Iterable, List, Optional + +import numpy as np +from skimage import io + + +CHANNEL_ORDER = ("dapi", "fitc", "cy3", "cy5") +IMAGE_EXTENSIONS = {".tif", ".tiff"} + + +@dataclass +class Metaphase: + """Container describing a metaphase and its channel image files.""" + + name: str + path: Path + channels: Dict[str, List[Path]] = field(default_factory=dict) + + def available_channels(self) -> List[str]: + """Return the list of available channel names ordered for display.""" + ordered = [c for c in CHANNEL_ORDER if c in self.channels] + tail = sorted(set(self.channels) - set(ordered)) + return ordered + tail + + +@dataclass +class Slide: + """Collection of metaphases belonging to a slide.""" + + name: str + path: Path + metaphases: Dict[str, Metaphase] = field(default_factory=dict) + + def available_metaphases(self) -> List[str]: + return sorted(self.metaphases) + + +@dataclass +class Dataset: + """Description of all slides discovered below a root folder.""" + + root: Path + slides: Dict[str, Slide] = field(default_factory=dict) + + def available_slides(self) -> List[str]: + return sorted(self.slides) + + +def _collect_channel_files(channel_dir: Path) -> List[Path]: + files = [p for p in channel_dir.iterdir() if p.suffix.lower() in IMAGE_EXTENSIONS] + files.sort() + return files + + +def discover_dataset(root: Path) -> Optional[Dataset]: + """Return a :class:`Dataset` from *root* or ``None`` if nothing is found.""" + root = root.expanduser().resolve() + if not root.exists(): + return None + + slides: Dict[str, Slide] = {} + for slide_dir in sorted(p for p in root.iterdir() if p.is_dir()): + metaphases: Dict[str, Metaphase] = {} + for metaphase_dir in sorted(p for p in slide_dir.iterdir() if p.is_dir()): + channels: Dict[str, List[Path]] = {} + for channel_dir in sorted(p for p in metaphase_dir.iterdir() if p.is_dir()): + files = _collect_channel_files(channel_dir) + if files: + channels[channel_dir.name.lower()] = files + if channels: + metaphases[metaphase_dir.name] = Metaphase( + name=metaphase_dir.name, + path=metaphase_dir, + channels=channels, + ) + if metaphases: + slides[slide_dir.name] = Slide( + name=slide_dir.name, + path=slide_dir, + metaphases=metaphases, + ) + if not 
slides: + return None + return Dataset(root=root, slides=slides) + + +def load_image(path: Path) -> np.ndarray: + """Load an image as ``float32`` for further processing.""" + arr = io.imread(str(path)).astype(np.float32) + return arr + + +def ensure_dataset(paths: Iterable[Path]) -> Dataset: + """Return the first dataset found in *paths* or create a synthetic one.""" + for candidate in paths: + dataset = discover_dataset(candidate) + if dataset is not None: + return dataset + + from .fake_data import create_synthetic_dataset + + synthetic_root = create_synthetic_dataset() + dataset = discover_dataset(synthetic_root) + if dataset is None: + raise RuntimeError("Failed to create synthetic dataset") + return dataset diff --git a/viewer/utils/fake_data.py b/viewer/utils/fake_data.py new file mode 100644 index 00000000000..f95dd80cb55 --- /dev/null +++ b/viewer/utils/fake_data.py @@ -0,0 +1,55 @@ +"""Synthetic dataset generation for development environments.""" +from __future__ import annotations + +import json +import tempfile +from pathlib import Path +from typing import Tuple + +import numpy as np +from skimage import draw, io, filters + + +def _make_channel(seed: int, shape: Tuple[int, int]) -> np.ndarray: + rng = np.random.default_rng(seed) + base = rng.normal(loc=2000, scale=300, size=shape).astype(np.float32) + # Add a few Gaussian blobs to mimic chromosomes. + for _ in range(8): + rr, cc = draw.disk( + center=(rng.integers(20, shape[0] - 20), rng.integers(20, shape[1] - 20)), + radius=rng.integers(8, 18), + shape=shape, + ) + base[rr, cc] += rng.uniform(1200, 2500) + base = filters.gaussian(base, sigma=1.5) + base -= base.min() + base /= base.max() + return (base * 4095).astype(np.uint16) + + +def create_synthetic_dataset() -> Path: + """Create a temporary dataset compatible with the viewer layout.""" + root = Path(tempfile.mkdtemp(prefix="cyto-viewer-demo-")) + slide_dir = root / "demo-slide" + metaphase_dir = slide_dir / "1" + metaphase_dir.mkdir(parents=True, exist_ok=True) + + channels = { + "dapi": _make_channel(1, (256, 256)), + "cy3": _make_channel(2, (256, 256)), + "cy5": _make_channel(3, (256, 256)), + } + + for name, arr in channels.items(): + channel_dir = metaphase_dir / name + channel_dir.mkdir() + io.imsave(str(channel_dir / "1.tif"), arr) + + # Default config file showcasing the JSON format. + config_path = slide_dir / "config.json" + config_path.write_text(json.dumps({ + "preprocessing": {}, + "alignment": {}, + "segmentation": {}, + }, indent=2)) + return root
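
For reference, a minimal headless sketch (not applied by this patch) of how the new modules compose, assuming the `viewer` package layout introduced above; the `np.clip`/`* 255` rescaling before `adaptive_threshold` is an addition here, since that helper casts to `uint8` before checking the dynamic range:

```python
import numpy as np

from viewer.preprocessing import PreprocessingSettings, preprocess_image
from viewer.segmentation import SegmentationSettings, adaptive_threshold, apply_filters
from viewer.utils import data
from viewer.utils.fake_data import create_synthetic_dataset

# Generate the demo layout (demo-slide/1/{dapi,cy3,cy5}/1.tif) and rediscover it.
root = create_synthetic_dataset()
dataset = data.discover_dataset(root)
metaphase = dataset.slides["demo-slide"].metaphases["1"]

# Load the first DAPI frame and run the default preprocessing chain
# (white top-hat, median background subtraction, 2-98 % percentile stretch).
dapi = data.load_image(metaphase.channels["dapi"][0])
pre = preprocess_image(dapi, PreprocessingSettings())

# Rescale to 8-bit before thresholding, then filter the mask by object size.
settings = SegmentationSettings(block_size=35, offset=-10, min_size=50)
mask = adaptive_threshold(np.clip(pre, 0.0, 1.0) * 255.0, settings)
mask = apply_filters(mask, settings)
print(mask.shape, mask.dtype, int(mask.sum()))
```

The `PreprocessingSettings`/`SegmentationSettings` dataclasses used here are the same objects the panels expose through `settings_for()`, so scripted runs and the GUI share a single configuration path.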