diff --git a/pyidi/GUIs/__init__.py b/pyidi/GUIs/__init__.py new file mode 100644 index 0000000..4052991 --- /dev/null +++ b/pyidi/GUIs/__init__.py @@ -0,0 +1,29 @@ +import typing + +try: + import PyQt6 + + HAS_PYQT6 = True +except ImportError: + HAS_PYQT6 = False + +if HAS_PYQT6 or typing.TYPE_CHECKING: + from .subset_selection import SelectionGUI + from .result_viewer import ResultViewer +else: + class SelectionGUI: + def __init__(self, video): + pass + + def show_displacement(self, data): + raise RuntimeError("SelectionGUI requires PyQt6: pip install pyidi[qt]") + + class ResultViewer: + def __init__(self): + pass + + def show_displacement(self, data): + raise RuntimeError("ResultViewer requires PyQt6: pip install pyidi[qt]") + +from .selection import SubsetSelection +from .gui import GUI \ No newline at end of file diff --git a/pyidi/gui.py b/pyidi/GUIs/gui.py similarity index 99% rename from pyidi/gui.py rename to pyidi/GUIs/gui.py index 6313eb6..121f65b 100644 --- a/pyidi/gui.py +++ b/pyidi/GUIs/gui.py @@ -6,10 +6,10 @@ import warnings warnings.simplefilter("default") -from . import tools +from .. import tools from . import selection -from .methods import SimplifiedOpticalFlow -from .methods import LucasKanade +from ..methods import SimplifiedOpticalFlow +from ..methods import LucasKanade NO_METHOD = '---' add_vertical_stretch = True diff --git a/pyidi/GUIs/result_viewer.py b/pyidi/GUIs/result_viewer.py new file mode 100644 index 0000000..2e57f34 --- /dev/null +++ b/pyidi/GUIs/result_viewer.py @@ -0,0 +1,1021 @@ +import numpy as np +import pyqtgraph as pg +from PyQt6 import QtWidgets, QtCore, QtGui +import matplotlib.pyplot as plt +import matplotlib.cm as cm +import matplotlib.colors as mcolors +import sys + +class RegionSelectViewBox(pg.ViewBox): + """Custom ViewBox that handles region selection with mouse events.""" + + def __init__(self, parent_viewer): + super().__init__() + self.parent_viewer = parent_viewer + self.region_start = None + self.region_current = None + self.dragging = False + + def mousePressEvent(self, ev): + if (self.parent_viewer.region_selection_active and + ev.button() == QtCore.Qt.MouseButton.LeftButton and + ev.modifiers() & QtCore.Qt.KeyboardModifier.ControlModifier): + # Start region selection + self.region_start = self.mapSceneToView(ev.scenePos()) + self.dragging = True + ev.accept() + else: + super().mousePressEvent(ev) + + def mouseMoveEvent(self, ev): + if self.dragging and self.parent_viewer.region_selection_active: + # Update region selection + self.region_current = self.mapSceneToView(ev.scenePos()) + self.parent_viewer.update_region_selection(self.region_start, self.region_current) + ev.accept() + else: + super().mouseMoveEvent(ev) + + def mouseReleaseEvent(self, ev): + if (self.dragging and self.parent_viewer.region_selection_active and + ev.button() == QtCore.Qt.MouseButton.LeftButton): + # Finish region selection + self.region_current = self.mapSceneToView(ev.scenePos()) + self.parent_viewer.finish_region_selection(self.region_start, self.region_current) + self.dragging = False + ev.accept() + else: + super().mouseReleaseEvent(ev) + +class ResultViewer(QtWidgets.QMainWindow): + def __init__(self, video, displacements, points, fps=30, magnification=1, point_size=10, colormap="cool"): + """ + The results from the pyidi analysis can directly be passed to this class: + + - ``video``: can be a ``VideoReader`` object (or numpy array of correct shape). + - ``displacements``: directly the return from the ``get_displacements`` method or mode shapes. + - ``points``: the points used for the analysis, which were passed to the ``set_points`` method. + + Parameters + ---------- + video : np.ndarray or VideoReader + Array of shape (n_frames, height, width) containing the video frames. + displacements : np.ndarray + Array of shape (n_frames, n_points, 2) for time-series displacements OR + Array of shape (n_points, 2) for mode shapes. + points : np.ndarray + Array of shape (n_points, 2) containing the grid points. + fps : int, optional + Frames per second for the video playback, by default 30. + magnification : int, optional + Magnification factor for the displacements, by default 1. + point_size : int, optional + Size of the points in pixels, by default 10. + colormap : str, optional + Name of the colormap to use for the arrows, by default "cool". + """ + # Create QApplication if it doesn't exist + app = QtWidgets.QApplication.instance() + if app is None: + app = QtWidgets.QApplication([]) + + super().__init__() + + # Coordinate transformation to match viewer function behavior + from ..video_reader import VideoReader + if isinstance(video, VideoReader): + self.video = video.get_frames() + else: + self.video = video + + # Check if displacements are 2D (mode shapes) or 3D (time-series) + if displacements.ndim == 2: + # Mode shapes: shape (n_points, 2) + self.is_mode_shape = True + self.displacements = displacements[:, ::-1] # Flip x,y coordinates + self.time_per_period = 1.0 # Seconds + else: + # Time-series displacements: shape (n_frames, n_points, 2) + self.is_mode_shape = False + self.displacements = displacements[:, :, ::-1] # Flip x,y coordinates + + self.grid = points[:, ::-1] + 0.5 # Flip x,y coordinates + self.fps = fps + self.magnification = magnification + self.points_size = point_size + self.current_frame = 0 + + self.disp_max = np.max(np.abs(displacements)) + self.colormap = colormap + + # Region selection variables + self.region_selection_active = False + self.region_start_point = None + self.region_end_point = None + self.region_rect = None + self.region_overlay = None + self.selected_region = None # (x, y, width, height) in image coordinates + + self.timer = QtCore.QTimer() + self.timer.timeout.connect(self.next_frame) + + self.init_ui() + self.update_frame() + + # Start the GUI + self.show() + # Only call sys.exit if not in IPython + if not hasattr(sys, 'ps1'): # Not interactive + sys.exit(app.exec()) + else: + app.exec() # Don't raise SystemExit in IPython + + def init_ui(self): + # Style + self.setStyleSheet(""" + QTabWidget::pane { border: 0; } + QPushButton { + background-color: #444; + color: white; + padding: 6px 12px; + border: 1px solid #555; + border-radius: 4px; + } + QPushButton:checked { + background-color: #0078d7; + border: 1px solid #005bb5; + } + QGroupBox { + font-weight: bold; + border: 2px solid #555; + border-radius: 8px; + margin-top: 15px; + padding-top: 15px; + background-color: #3a3a3a; + color: white; + } + QGroupBox::title { + subcontrol-origin: margin; + subcontrol-position: top left; + left: 15px; + top: 4px; + padding: 2px 10px; + color: #e0e0e0; + background-color: #4a4a4a; + border: 1px solid #666; + border-radius: 4px; + font-size: 11px; + font-weight: bold; + } + """) + + central_widget = QtWidgets.QWidget() + main_layout = QtWidgets.QVBoxLayout(central_widget) + main_layout.setContentsMargins(0, 0, 0, 0) + main_layout.setSpacing(0) + + # Add splitter + self.splitter = QtWidgets.QSplitter(QtCore.Qt.Orientation.Horizontal) + main_layout.addWidget(self.splitter, stretch=1) + + # === Video Display === + self.view = pg.GraphicsLayoutWidget() + self.img_item = pg.ImageItem() + self.scatter = pg.ScatterPlotItem(size=self.points_size, brush='r', pxMode=True) + + # Create custom viewbox for region selection + self.viewbox = RegionSelectViewBox(self) + self.view.addItem(self.viewbox) + + self.viewbox.addItem(self.img_item) + self.viewbox.addItem(self.scatter) + self.viewbox.setAspectLocked(True) + self.viewbox.invertY(True) + self.arrow_shafts = [] + self.splitter.addWidget(self.view) + + # === Right Control Panel === + self.control_widget = QtWidgets.QWidget() + self.control_layout = QtWidgets.QVBoxLayout(self.control_widget) + + # Display controls group + display_group = QtWidgets.QGroupBox("Display Controls") + display_layout = QtWidgets.QVBoxLayout(display_group) + + # Point size control + point_size_layout = QtWidgets.QHBoxLayout() + point_size_layout.addWidget(QtWidgets.QLabel("Point size:")) + + self.point_size_spin = QtWidgets.QSpinBox() + self.point_size_spin.setRange(1, 100) + self.point_size_spin.setValue(self.points_size) + self.point_size_spin.setSuffix("px") + self.point_size_spin.setFixedWidth(80) + self.point_size_spin.valueChanged.connect(self.update_point_size_from_spinbox) + point_size_layout.addWidget(self.point_size_spin) + + point_size_layout.addStretch() # Push everything to the left + display_layout.addLayout(point_size_layout) + + self.point_size_slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal) + self.point_size_slider.setRange(1, 20) + self.point_size_slider.setValue(min(20, self.points_size)) + self.point_size_slider.valueChanged.connect(self.update_point_size_from_slider) + display_layout.addWidget(self.point_size_slider) + + # Magnification control + mag_layout = QtWidgets.QHBoxLayout() + mag_layout.addWidget(QtWidgets.QLabel("Magnify:")) + + self.mag_spin = QtWidgets.QDoubleSpinBox() + self.mag_spin.setRange(0.01, 999999) # No practical upper limit + self.mag_spin.setSingleStep(0.01) + self.mag_spin.setValue(self.magnification) + self.mag_spin.setSuffix("x") + self.mag_spin.setFixedWidth(80) + self.mag_spin.valueChanged.connect(self.update_mag_from_spinbox) + mag_layout.addWidget(self.mag_spin) + + mag_layout.addStretch() # Push everything to the left + display_layout.addLayout(mag_layout) + + self.mag_slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal) + self.mag_slider.setRange(1, 1000) # 0.1x to 10x (in percent: 10% to 1000%) + self.mag_slider.setValue(int(self.magnification * 100)) + self.mag_slider.valueChanged.connect(self.update_mag_from_slider) + display_layout.addWidget(self.mag_slider) + + # Show arrows checkbox + self.arrows_checkbox = QtWidgets.QCheckBox("Show arrows") + self.arrows_checkbox.stateChanged.connect(self.update_frame) + display_layout.addWidget(self.arrows_checkbox) + + self.control_layout.addWidget(display_group) + + # Playback controls group + playback_group = QtWidgets.QGroupBox("Playback Controls") + playback_layout = QtWidgets.QVBoxLayout(playback_group) + + # FPS control + fps_layout = QtWidgets.QHBoxLayout() + fps_layout.addWidget(QtWidgets.QLabel("FPS:")) + + self.fps_spin = QtWidgets.QSpinBox() + self.fps_spin.setRange(1, 240) + self.fps_spin.setValue(self.fps) + self.fps_spin.setFixedWidth(80) + self.fps_spin.valueChanged.connect(self.update_fps_from_spinbox) + fps_layout.addWidget(self.fps_spin) + + fps_layout.addStretch() # Push everything to the left + playback_layout.addLayout(fps_layout) + + self.fps_slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal) + self.fps_slider.setRange(1, 240) + self.fps_slider.setValue(self.fps) + self.fps_slider.valueChanged.connect(self.update_fps_from_slider) + playback_layout.addWidget(self.fps_slider) + + self.control_layout.addWidget(playback_group) + + # Export controls group + export_group = QtWidgets.QGroupBox("Export Video") + export_layout = QtWidgets.QVBoxLayout(export_group) + + # Quality/FPS for export + export_layout.addWidget(QtWidgets.QLabel("Export FPS:")) + self.export_fps_spin = QtWidgets.QSpinBox() + self.export_fps_spin.setRange(1, 120) + self.export_fps_spin.setValue(30) + export_layout.addWidget(self.export_fps_spin) + + # Export resolution + export_layout.addWidget(QtWidgets.QLabel("Export Resolution:")) + self.export_resolution_combo = QtWidgets.QComboBox() + self.export_resolution_combo.addItems([ + "2x pixel scale", + "4x pixel scale", + "6x pixel scale", + "8x pixel scale", + ]) + self.export_resolution_combo.setCurrentText("4x pixel scale") + export_layout.addWidget(self.export_resolution_combo) + + # Region selection controls + export_layout.addWidget(QtWidgets.QLabel("Region Selection:")) + + region_layout = QtWidgets.QHBoxLayout() + + # Region selection button + self.region_select_button = QtWidgets.QPushButton("Select Region") + self.region_select_button.setCheckable(True) + self.region_select_button.clicked.connect(self.toggle_region_selection) + region_layout.addWidget(self.region_select_button) + + # Clear region button + self.clear_region_button = QtWidgets.QPushButton("Clear") + self.clear_region_button.clicked.connect(self.clear_region_selection) + self.clear_region_button.setEnabled(False) + region_layout.addWidget(self.clear_region_button) + + export_layout.addLayout(region_layout) + + # Region info label + self.region_info_label = QtWidgets.QLabel("Full frame will be exported") + self.region_info_label.setStyleSheet("font-size: 10px; color: #aaa;") + export_layout.addWidget(self.region_info_label) + + # Frame range controls (only for non-mode shape videos) + if not self.is_mode_shape: + self.frame_range_label = QtWidgets.QLabel("Frame Range:") + export_layout.addWidget(self.frame_range_label) + + frame_range_layout = QtWidgets.QHBoxLayout() + + # Start frame + self.start_frame_spin = QtWidgets.QSpinBox() + self.start_frame_spin.setRange(0, self.video.shape[0] - 1) + self.start_frame_spin.setValue(0) + self.start_frame_spin.setFixedWidth(80) + self.start_frame_spin.valueChanged.connect(self.on_start_frame_changed) + frame_range_layout.addWidget(self.start_frame_spin) + + # Stop frame + self.stop_frame_spin = QtWidgets.QSpinBox() + self.stop_frame_spin.setRange(0, self.video.shape[0] - 1) + self.stop_frame_spin.setValue(self.video.shape[0] - 1) + self.stop_frame_spin.setFixedWidth(80) + self.stop_frame_spin.valueChanged.connect(self.on_stop_frame_changed) + frame_range_layout.addWidget(self.stop_frame_spin) + + export_layout.addLayout(frame_range_layout) + + # Update the label with initial frame count + self.update_frame_range_label() + + # Full range checkbox + self.full_range_checkbox = QtWidgets.QCheckBox("Full Range") + self.full_range_checkbox.setChecked(True) # Initially checked since defaults are full range + self.full_range_checkbox.stateChanged.connect(self.on_full_range_checkbox_changed) + export_layout.addWidget(self.full_range_checkbox) + + # Duration for mode shapes + if self.is_mode_shape: + export_layout.addWidget(QtWidgets.QLabel("Duration (seconds):")) + self.duration_spin = QtWidgets.QDoubleSpinBox() + self.duration_spin.setRange(0.5, 60.0) + self.duration_spin.setValue(2.0) + self.duration_spin.setSingleStep(0.5) + export_layout.addWidget(self.duration_spin) + + # Export button + self.export_button = QtWidgets.QPushButton("Export Video") + self.export_button.clicked.connect(self.export_video) + export_layout.addWidget(self.export_button) + + # Progress bar + self.export_progress = QtWidgets.QProgressBar() + self.export_progress.setVisible(False) + export_layout.addWidget(self.export_progress) + + self.control_layout.addWidget(export_group) + + self.control_layout.addStretch(1) + + self.splitter.addWidget(self.control_widget) + + # Set splitter proportions + self.splitter.setStretchFactor(0, 5) # Video area grows more + self.splitter.setStretchFactor(1, 0) # Controls panel fixed by content + + # Set initial width for right panel + self.control_widget.setMinimumWidth(150) + self.control_widget.setMaximumWidth(600) + self.splitter.setSizes([800, 200]) # Initial left/right width + + # === Bottom Playback Controls === + playback_layout = QtWidgets.QHBoxLayout() + playback_layout.setContentsMargins(5, 5, 5, 5) + + self.play_button = QtWidgets.QPushButton("▶️") + self.play_button.clicked.connect(self.toggle_playback) + playback_layout.addWidget(self.play_button) + + + playback_layout.addWidget(QtWidgets.QLabel(" Frame: ")) + self.slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal) + if self.is_mode_shape: + self.slider.setRange(0, int(self.fps * self.time_per_period) - 1) + else: + self.slider.setRange(0, self.video.shape[0] - 1) + self.slider.setValue(0) + self.slider.valueChanged.connect(self.on_slider) + playback_layout.addWidget(self.slider) + + self.frame_spinbox = QtWidgets.QSpinBox() + if self.is_mode_shape: + self.frame_spinbox.setRange(0, int(self.fps * self.time_per_period) - 1) + else: + self.frame_spinbox.setRange(0, self.video.shape[0] - 1) + + self.frame_spinbox.valueChanged.connect(self.on_slider) + self.frame_spinbox.setValue(0) + playback_layout.addWidget(self.frame_spinbox) + + main_layout.addLayout(playback_layout) + + # === Finalize === + self.setCentralWidget(central_widget) + self.setWindowTitle("Displacement Viewer") + self.resize(1200, 600) + + def toggle_playback(self): + if self.timer.isActive(): + self.timer.stop() + self.play_button.setText("▶️") + else: + self.set_timer_interval() + self.timer.start() + self.play_button.setText("⏹️") + + def set_timer_interval(self): + self.fps = self.fps_spin.value() + interval_ms = int(1000 / self.fps) + self.timer.setInterval(interval_ms) + + def on_fps_change(self): + if self.timer.isActive(): + self.set_timer_interval() + + if self.is_mode_shape: + self.slider.setMaximum(int(self.fps * self.time_per_period) - 1) + self.frame_spinbox.setMaximum(int(self.fps * self.time_per_period) - 1) + + def update_fps_from_slider(self, value): + self.fps = value + self.fps_spin.blockSignals(True) # Prevent recursive updates + self.fps_spin.setValue(value) + self.fps_spin.blockSignals(False) + self.on_fps_change() + + def update_fps_from_spinbox(self, value): + self.fps = value + self.fps_slider.blockSignals(True) # Prevent recursive updates + self.fps_slider.setValue(value) + self.fps_slider.blockSignals(False) + self.on_fps_change() + + def update_mag_from_slider(self, value): + # Convert slider value (1-1000) to magnification (0.01-10.0) + magnification = value / 100.0 + self.magnification = magnification + + # Update spinbox without triggering its signal + self.mag_spin.blockSignals(True) + self.mag_spin.setValue(magnification) + self.mag_spin.blockSignals(False) + + self.update_frame() + + def update_mag_from_spinbox(self, value): + # Convert spinbox value to magnification + magnification = value + self.magnification = magnification + + # Update slider, clamping to its range and converting to int + slider_value = int(max(1, min(1000, value * 100))) + self.mag_slider.blockSignals(True) + self.mag_slider.setValue(slider_value) + self.mag_slider.blockSignals(False) + + self.update_frame() + + def update_point_size_from_slider(self, value): + # Update the internal point size + self.points_size = value + + # Update spinbox without triggering its signal + self.point_size_spin.blockSignals(True) + self.point_size_spin.setValue(value) + self.point_size_spin.blockSignals(False) + + # Update the actual display + self.scatter.setSize(value) + + def update_point_size_from_spinbox(self, value): + # Update the internal point size + self.points_size = value + + # Update slider, clamping to its range + slider_value = min(20, max(1, value)) + self.point_size_slider.blockSignals(True) + self.point_size_slider.setValue(slider_value) + self.point_size_slider.blockSignals(False) + + # Update the actual display + self.scatter.setSize(value) + + def update_point_size(self): + # Keep this method for backward compatibility if needed + size = self.point_size_spin.value() + self.scatter.setSize(size) + + def on_start_frame_changed(self, value): + # Ensure start frame is not greater than stop frame + if hasattr(self, 'stop_frame_spin') and value > self.stop_frame_spin.value(): + self.stop_frame_spin.setValue(value) + + # Update the frame range label + self.update_frame_range_label() + + # Update checkbox state based on whether we have full range + self.update_full_range_checkbox_state() + + def on_stop_frame_changed(self, value): + # Ensure stop frame is not less than start frame + if hasattr(self, 'start_frame_spin') and value < self.start_frame_spin.value(): + self.start_frame_spin.setValue(value) + + # Update the frame range label + self.update_frame_range_label() + + # Update checkbox state based on whether we have full range + self.update_full_range_checkbox_state() + + def update_frame_range_label(self): + """Update the frame range label with current frame count.""" + if not self.is_mode_shape and hasattr(self, 'frame_range_label'): + start_frame = self.start_frame_spin.value() + stop_frame = self.stop_frame_spin.value() + total_frames = stop_frame - start_frame + 1 + self.frame_range_label.setText(f"Frame Range: ({total_frames} frames)") + + def on_full_range_checkbox_changed(self, state): + """Handle full range checkbox state changes.""" + if not self.is_mode_shape: + if state == QtCore.Qt.CheckState.Checked.value: + # Set to full range + self.start_frame_spin.blockSignals(True) + self.stop_frame_spin.blockSignals(True) + self.start_frame_spin.setValue(0) + self.stop_frame_spin.setValue(self.video.shape[0] - 1) + self.start_frame_spin.blockSignals(False) + self.stop_frame_spin.blockSignals(False) + + # Update the frame range label + self.update_frame_range_label() + + def update_full_range_checkbox_state(self): + """Update the checkbox state based on current spinbox values.""" + if not self.is_mode_shape and hasattr(self, 'full_range_checkbox'): + is_full_range = (self.start_frame_spin.value() == 0 and + self.stop_frame_spin.value() == self.video.shape[0] - 1) + + # Block signals to prevent recursive calls + self.full_range_checkbox.blockSignals(True) + self.full_range_checkbox.setChecked(is_full_range) + self.full_range_checkbox.blockSignals(False) + + def set_full_range(self): + """Set the frame range to cover the full video.""" + if not self.is_mode_shape: + self.start_frame_spin.setValue(0) + self.stop_frame_spin.setValue(self.video.shape[0] - 1) + + def toggle_region_selection(self): + """Toggle region selection mode.""" + self.region_selection_active = self.region_select_button.isChecked() + + if self.region_selection_active: + self.region_select_button.setText("Cancel Selection") + self.region_select_button.setStyleSheet("background-color: #d73a00;") + # Clear any existing region + self.clear_region_graphics() + else: + self.region_select_button.setText("Select Region") + self.region_select_button.setStyleSheet("") + # Clear any temporary selection graphics + self.clear_region_graphics() + + def clear_region_selection(self): + """Clear the current region selection.""" + self.selected_region = None + self.clear_region_graphics() + self.clear_region_button.setEnabled(False) + self.region_info_label.setText("Full frame will be exported") + + # Reset the selection button if it was active + if self.region_selection_active: + self.region_select_button.setChecked(False) + self.toggle_region_selection() + + def clear_region_graphics(self): + """Remove region selection graphics from the view.""" + if self.region_rect is not None: + self.viewbox.removeItem(self.region_rect) + self.region_rect = None + if self.region_overlay is not None: + self.viewbox.removeItem(self.region_overlay) + self.region_overlay = None + + def update_region_selection(self, start_point, current_point): + """Update the region selection rectangle during dragging.""" + if start_point is None or current_point is None: + return + + # Clear previous rectangle + if self.region_rect is not None: + self.viewbox.removeItem(self.region_rect) + + # Create new rectangle + x1, y1 = start_point.x(), start_point.y() + x2, y2 = current_point.x(), current_point.y() + + # Ensure proper ordering + x_min, x_max = min(x1, x2), max(x1, x2) + y_min, y_max = min(y1, y2), max(y1, y2) + + # Create rectangle item + self.region_rect = pg.RectROI([x_min, y_min], [x_max - x_min, y_max - y_min], + pen=pg.mkPen(color='red', width=2), + movable=False, removable=False) + self.viewbox.addItem(self.region_rect) + + def finish_region_selection(self, start_point, end_point): + """Finish region selection and apply overlay.""" + if start_point is None or end_point is None: + return + + # Calculate region bounds + x1, y1 = start_point.x(), start_point.y() + x2, y2 = end_point.x(), end_point.y() + + # Ensure proper ordering and clip to image bounds + video_height, video_width = self.video[0].shape + x_min = max(0, min(x1, x2)) + x_max = min(video_width, max(x1, x2)) + y_min = max(0, min(y1, y2)) + y_max = min(video_height, max(y1, y2)) + + # Store the selected region + self.selected_region = (int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)) + + # Update UI + self.region_select_button.setChecked(False) + self.toggle_region_selection() + self.clear_region_button.setEnabled(True) + self.region_info_label.setText(f"Region: {self.selected_region[2]}x{self.selected_region[3]} pixels") + + # Create overlay effect + self.create_region_overlay() + + def create_region_overlay(self): + """Create a semi-transparent overlay outside the selected region.""" + if self.selected_region is None: + return + + # Clear existing overlay + if self.region_overlay is not None: + self.viewbox.removeItem(self.region_overlay) + + # Create overlay using ImageItem with alpha channel + video_height, video_width = self.video[0].shape + overlay = np.zeros((video_height, video_width, 4), dtype=np.uint8) + + # Set alpha to 128 (semi-transparent) for the entire overlay + overlay[:, :, 3] = 128 + + # Make the selected region fully transparent + x, y, w, h = self.selected_region + overlay[y:y+h, x:x+w, 3] = 0 + + # Create ImageItem for overlay + self.region_overlay = pg.ImageItem(overlay.transpose((1, 0, 2))) + self.viewbox.addItem(self.region_overlay) + + def next_frame(self): + if self.is_mode_shape: + self.current_frame = (self.current_frame + 1) % int(self.fps * self.time_per_period) + else: + self.current_frame = (self.current_frame + 1) % self.video.shape[0] + + self.slider.setValue(self.current_frame) + + def on_slider(self, val): + self.current_frame = val + self.frame_spinbox.setValue(val) + self.slider.setValue(val) + self.update_frame() + + def update_frame(self): + scale = self.magnification + + if self.is_mode_shape: + frame = self.video[0] + self.img_item.setImage(frame.T) + + # Calculate time for sinusoidal motion + t = self.current_frame / self.fps # Convert frame to time in seconds + + # Calculate displacement amplitude using sinusoidal motion + displ_raw = self.displacements + amplitude = np.abs(displ_raw) + phase = np.angle(displ_raw) + + # Calculate displacement using sinusoidal motion + displ = scale * amplitude * np.sin(2 * np.pi * t - phase) + + else: + # Regular displacement animation + frame = self.video[self.current_frame] + self.img_item.setImage(frame.T) + + displ = self.displacements[:, self.current_frame, :] * scale + + # Update scatter plot with displaced points + displaced_pts = self.grid + displ + self.scatter.setData(pos=displaced_pts[:, [0, 1]]) + + if self.arrows_checkbox.isChecked(): + self.scatter.setVisible(False) + + displ = displaced_pts - self.grid + magnitudes = np.linalg.norm(displ, axis=1) + + norm = mcolors.Normalize(vmin=0, vmax=self.disp_max*scale) + cmap = plt.colormaps[self.colormap] # Use the specified colormap + + # Clear old shafts + for shaft in self.arrow_shafts: + self.viewbox.removeItem(shaft) + self.arrow_shafts.clear() + + # Add colored shafts + for pt0, pt1, mag in zip(self.grid, displaced_pts, magnitudes): + color = cmap(norm(mag)) + color_rgb = tuple(int(255 * c) for c in color[:3]) + + shaft = pg.PlotCurveItem( + x=[pt0[0], pt1[0]], + y=[pt0[1], pt1[1]], + pen=pg.mkPen(color_rgb, width=self.point_size_spin.value()) + ) + self.arrow_shafts.append(shaft) + self.viewbox.addItem(shaft) + + else: + self.scatter.setVisible(True) + for shaft in self.arrow_shafts: + self.viewbox.removeItem(shaft) + self.arrow_shafts.clear() + + # Ensure region overlay is on top if it exists + if self.region_overlay is not None: + self.viewbox.removeItem(self.region_overlay) + self.viewbox.addItem(self.region_overlay) + + def export_video(self): + """Export the current visualization as a video file with pixel-perfect rendering.""" + try: + import cv2 + except ImportError: + QtWidgets.QMessageBox.warning(self, "Missing Dependency", + "OpenCV (cv2) is required for video export.\n" + "Install it with: pip install opencv-python") + return + + # Get export parameters + export_fps = self.export_fps_spin.value() + + # Get pixel scaling factor from resolution selection + resolution_text = self.export_resolution_combo.currentText() + if "4x pixel scale" in resolution_text: + pixel_scale = 4 # Each video pixel becomes 4x4 pixels in export + elif "2x pixel scale" in resolution_text: + pixel_scale = 2 # Each video pixel becomes 2x2 pixels in export + elif "6x pixel scale" in resolution_text: + pixel_scale = 6 # Each video pixel becomes 6x6 pixels in export + else: # 4K + pixel_scale = 8 # Each video pixel becomes 8x8 pixels in export + + # Calculate export dimensions based on video dimensions and pixel scaling + video_height, video_width = self.video[0].shape + + # Handle region selection + if self.selected_region is not None: + region_x, region_y, region_width, region_height = self.selected_region + export_width = region_width * pixel_scale + export_height = region_height * pixel_scale + else: + export_width = video_width * pixel_scale + export_height = video_height * pixel_scale + + # Use MP4 with high quality settings + default_ext = "mp4" + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + + # File dialog for save location + file_path, _ = QtWidgets.QFileDialog.getSaveFileName( + self, "Export Video", f"displacement_video.{default_ext}", + f"MP4 files (*.{default_ext});;All files (*.*)" + ) + + if not file_path: + return + + # Store current state to restore later + original_frame = self.current_frame + original_timer_active = self.timer.isActive() + if original_timer_active: + self.timer.stop() + + # Calculate total frames for export + if self.is_mode_shape: + duration = self.duration_spin.value() + total_frames = int(export_fps * duration) + start_frame = 0 + stop_frame = total_frames - 1 + else: + start_frame = self.start_frame_spin.value() + stop_frame = self.stop_frame_spin.value() + total_frames = stop_frame - start_frame + 1 + + # Initialize video writer + writer = cv2.VideoWriter(file_path, fourcc, export_fps, (export_width, export_height)) + + if not writer.isOpened(): + QtWidgets.QMessageBox.critical(self, "Export Error", + "Failed to create video writer. Check file path and format.") + return + + # Show progress bar + self.export_progress.setVisible(True) + self.export_progress.setRange(0, total_frames) + self.export_button.setText("Exporting...") + self.export_button.setEnabled(False) + + try: + # Get current visualization parameters + scale = self.mag_spin.value() + show_arrows = self.arrows_checkbox.isChecked() + point_size = self.point_size_spin.value() + + for frame_idx in range(total_frames): + # Update progress + self.export_progress.setValue(frame_idx) + QtWidgets.QApplication.processEvents() # Keep UI responsive + + # Set the current frame + if self.is_mode_shape: + self.current_frame = frame_idx % int(self.fps * self.time_per_period) + # Get base frame for mode shapes + base_frame = self.video[0] + + # Calculate time for sinusoidal motion + t = self.current_frame / self.fps + displ_raw = self.displacements + amplitude = np.abs(displ_raw) + phase = np.angle(displ_raw) + displ = scale * amplitude * np.sin(2 * np.pi * t - phase) + else: + # For regular videos, use the actual frame index within the specified range + actual_frame_idx = start_frame + frame_idx + self.current_frame = actual_frame_idx + base_frame = self.video[actual_frame_idx] + displ = self.displacements[:, actual_frame_idx, :] * scale + + # Create the export frame by scaling the video frame pixel-perfectly + # Convert to RGB for proper color handling + if len(base_frame.shape) == 2: # Grayscale + frame_rgb = np.stack([base_frame, base_frame, base_frame], axis=2) + else: + frame_rgb = base_frame + + # Apply region cropping if selected + if self.selected_region is not None: + region_x, region_y, region_width, region_height = self.selected_region + frame_rgb = frame_rgb[region_y:region_y+region_height, region_x:region_x+region_width] + + # Scale up the frame without interpolation (nearest neighbor) + export_frame = np.repeat(np.repeat(frame_rgb, pixel_scale, axis=0), pixel_scale, axis=1) + + # Calculate displaced points + displaced_pts = self.grid + displ + + # Calculate region offset for coordinate transformation + region_offset_x = 0 + region_offset_y = 0 + if self.selected_region is not None: + region_offset_x, region_offset_y = self.selected_region[0], self.selected_region[1] + + # Draw displacement visualization on the scaled frame + if show_arrows: + # Draw arrows showing displacement + magnitudes = np.linalg.norm(displ, axis=1) + norm = mcolors.Normalize(vmin=0, vmax=self.disp_max*scale) + cmap = plt.colormaps[self.colormap] + + for i, (pt0, pt1, mag) in enumerate(zip(self.grid, displaced_pts, magnitudes)): + # Apply region offset and scale coordinates to export resolution + start_pt = (int((pt0[0] - region_offset_x) * pixel_scale), + int((pt0[1] - region_offset_y) * pixel_scale)) + end_pt = (int((pt1[0] - region_offset_x) * pixel_scale), + int((pt1[1] - region_offset_y) * pixel_scale)) + + # Check if points are within the export frame bounds + if (0 <= start_pt[0] < export_width and 0 <= start_pt[1] < export_height and + 0 <= end_pt[0] < export_width and 0 <= end_pt[1] < export_height): + + # Get color for this magnitude + color = cmap(norm(mag)) + color_bgr = tuple(int(255 * c) for c in color[2::-1]) # Convert RGB to BGR + + # Draw arrow line + cv2.line(export_frame, start_pt, end_pt, color_bgr, + max(1, point_size * pixel_scale // 10)) + + # Draw arrow head + cv2.circle(export_frame, end_pt, max(1, point_size * pixel_scale // 5), + color_bgr, -1) + else: + # Draw points at displaced positions + for pt in displaced_pts: + center = (int((pt[0] - region_offset_x) * pixel_scale), + int((pt[1] - region_offset_y) * pixel_scale)) + + # Check if point is within the export frame bounds + if (0 <= center[0] < export_width and 0 <= center[1] < export_height): + cv2.circle(export_frame, center, max(1, point_size * pixel_scale // 5), + (0, 0, 255), -1) # Red circles + + # Ensure the frame is in the correct format and size + export_frame = np.clip(export_frame, 0, 255).astype(np.uint8) + + # Convert RGB to BGR for OpenCV + if len(export_frame.shape) == 3: + export_frame_bgr = cv2.cvtColor(export_frame, cv2.COLOR_RGB2BGR) + else: + export_frame_bgr = export_frame + + writer.write(export_frame_bgr) + + writer.release() + + # Create success message with frame range info + if self.is_mode_shape: + frame_info = f"Duration: {self.duration_spin.value():.1f}s" + else: + frame_info = f"Frames: {start_frame} to {stop_frame} ({total_frames} total)" + + # Add region info if applicable + region_info = "" + if self.selected_region is not None: + region_info = f"Region: {self.selected_region[2]}x{self.selected_region[3]} pixels\n" + + QtWidgets.QMessageBox.information(self, "Export Complete", + f"Video exported successfully to:\n{file_path}\n" + f"Resolution: {export_width}x{export_height} " + f"(pixel scale: {pixel_scale}x)\n" + f"{region_info}{frame_info}") + + except Exception as e: + import traceback + traceback.print_exc() + QtWidgets.QMessageBox.critical(self, "Export Error", + f"An error occurred during export:\n{str(e)}") + finally: + # Restore original state + self.current_frame = original_frame + self.update_frame() + if original_timer_active: + self.timer.start() + + # Reset UI + self.export_progress.setVisible(False) + self.export_button.setText("Export Video") + self.export_button.setEnabled(True) + writer.release() + + + +if __name__ == "__main__": + n_frames, height, width = 200, 300, 400 + n_points = 100 + frames = np.random.randint(0, 255, size=(n_frames, height, width), dtype=np.uint8) + + # Test with regular time-series displacements + displacements = 2 * (np.random.rand(n_points, n_frames, 2) - 0.5) + + # Test with mode shapes (2D array) + grid = np.stack(np.meshgrid(np.linspace(50, 250, int(np.sqrt(n_points))), + np.linspace(50, 350, int(np.sqrt(n_points)))), axis=-1).reshape(-1, 2)[:n_points] + + # Create a simple mode shape (e.g., first bending mode) + # displacements = np.zeros((n_points, 2)) + # for i, point in enumerate(grid): + # # Simple sinusoidal mode shape in y-direction + # displacements[i, 0] = 5 * np.sin(np.pi * point[0] / width) # y displacement + # displacements[i, 1] = 0 # no x displacement + + # Test mode shape viewer + ResultViewer(frames, displacements, grid) \ No newline at end of file diff --git a/pyidi/selection.py b/pyidi/GUIs/selection.py similarity index 100% rename from pyidi/selection.py rename to pyidi/GUIs/selection.py diff --git a/pyidi/GUIs/subset_selection.py b/pyidi/GUIs/subset_selection.py new file mode 100644 index 0000000..a0aa3b7 --- /dev/null +++ b/pyidi/GUIs/subset_selection.py @@ -0,0 +1,1574 @@ +import sys +import numpy as np +from PyQt6 import QtWidgets, QtCore +from pyqtgraph import GraphicsLayoutWidget, ImageItem, ScatterPlotItem +import pyqtgraph as pg +from matplotlib.path import Path +# import pyidi # Assuming pyidi is a custom module for video handling + +class BrushViewBox(pg.ViewBox): + def __init__(self, parent_gui, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setMouseMode(self.PanMode) + self.parent_gui = parent_gui + + def mouseClickEvent(self, ev): + if self.parent_gui.mode == "selection" and self.parent_gui.method_buttons["Brush"].isChecked(): + if self.parent_gui.ctrl_held: + ev.accept() + self.parent_gui.handle_brush_start(ev) + else: + ev.ignore() + else: + super().mouseClickEvent(ev) + + def mouseDragEvent(self, ev, axis=None): + # Handle gradient direction selection + if self.parent_gui.mode == "filter" and self.parent_gui.setting_direction: + if ev.isStart(): + pos = ev.scenePos() + if self.sceneBoundingRect().contains(pos): + point = self.mapSceneToView(pos) + self.parent_gui.gradient_direction_points = [(point.x(), point.y())] + self.parent_gui.gradient_direction_start = (point.x(), point.y()) + ev.accept() + return + elif ev.isFinish(): + pos = ev.scenePos() + if self.sceneBoundingRect().contains(pos): + point = self.mapSceneToView(pos) + if hasattr(self.parent_gui, 'gradient_direction_start'): + self.parent_gui.gradient_direction_points = [ + self.parent_gui.gradient_direction_start, + (point.x(), point.y()) + ] + self.parent_gui.compute_direction_vector() + self.parent_gui.update_direction_line() + # Toggle off the direction selection mode + self.parent_gui.direction_button.setChecked(False) + self.parent_gui.set_gradient_direction_mode() + self.parent_gui.compute_candidate_points_gradient_direction() + ev.accept() + return + else: + # During drag, update the line display + pos = ev.scenePos() + if self.sceneBoundingRect().contains(pos): + point = self.mapSceneToView(pos) + if hasattr(self.parent_gui, 'gradient_direction_start'): + temp_points = [ + self.parent_gui.gradient_direction_start, + (point.x(), point.y()) + ] + xs = [p[0] for p in temp_points] + ys = [p[1] for p in temp_points] + self.parent_gui.direction_line.setData(xs, ys) + ev.accept() + return + + if self.parent_gui.mode == "selection" and self.parent_gui.method_buttons["Brush"].isChecked(): + if self.parent_gui.ctrl_held: + ev.accept() + if ev.isStart(): + self.parent_gui._painting = True + self.parent_gui._brush_path = [] + self.parent_gui.handle_brush_start(ev) + elif ev.isFinish(): + self.parent_gui._painting = False + self.parent_gui.handle_brush_end(ev) + else: + self.parent_gui.handle_brush_move(ev) + return + # fallback: pan + super().mouseDragEvent(ev, axis) + +class SelectionGUI(QtWidgets.QMainWindow): + def __init__(self, video): + """Initialize the selection GUI for manual subset selection. + + To extract the points, use the ``get_points`` method or the ``points`` attribute. + + Parameters + ---------- + video : VideoReader or np.ndarray + The video to be analyzed. If a VideoReader object, it should be initialized with the video file. + """ + app = QtWidgets.QApplication.instance() + if app is None: + app = QtWidgets.QApplication([]) + + super().__init__() + + self.setWindowTitle("ROI Selection Tool") + self.resize(1200, 800) + + self._paint_mask = None # Same shape as the image + self._paint_radius = 10 # pixels + self.ctrl_held = False + self.brush_deselect_mode = False + self.installEventFilter(self) + + self.gradient_direction_points = [] + self.gradient_direction = None + self.setting_direction = False + + self.selected_points = [] + self.manual_points = [] + self.candidate_points = [] + self.drawing_polygons = [{'points': [], 'roi_points': []}] + self.active_polygon_index = 0 + self.grid_polygons = [{'points': [], 'roi_points': []}] + self.active_grid_index = 0 + self.brush_masks = [] # Store brush masks for recomputation + self.brush_points = [] # Store computed brush points separately + + # Add status bar for instructions + self.statusBar = self.statusBar() + self.statusBar.showMessage("Ready. Select a method to begin.") + + # Central widget + self.central_widget = QtWidgets.QWidget() + self.setCentralWidget(self.central_widget) + # Top-level layout for the central widget + self.main_layout = QtWidgets.QVBoxLayout(self.central_widget) + self.main_layout.setContentsMargins(0, 0, 0, 0) + self.main_layout.setSpacing(0) + + # Toolbar (fixed height) + self.mode_toolbar = QtWidgets.QWidget() + self.mode_toolbar_layout = QtWidgets.QHBoxLayout(self.mode_toolbar) + self.mode_toolbar_layout.setContentsMargins(5, 4, 5, 4) + self.mode_toolbar_layout.setSpacing(6) + + self.selection_mode_button = QtWidgets.QPushButton("Select") # Selection mode + self.filter_mode_button = QtWidgets.QPushButton("Filter") # Filter mode + for btn in [self.selection_mode_button, self.filter_mode_button]: + btn.setCheckable(True) + btn.setMinimumWidth(100) + self.mode_toolbar_layout.addWidget(btn) + + self.selection_mode_button.setChecked(True) + self.selection_mode_button.clicked.connect(lambda: self.switch_mode("selection")) + self.filter_mode_button.clicked.connect(lambda: self.switch_mode("filter")) + + self.mode_toolbar.setSizePolicy(QtWidgets.QSizePolicy.Policy.Expanding, QtWidgets.QSizePolicy.Policy.Fixed) + self.mode_toolbar.setMaximumHeight(self.selection_mode_button.sizeHint().height() + 12) + + self.main_layout.addWidget(self.mode_toolbar) + + # Add splitter directly and stretch it + self.splitter = QtWidgets.QSplitter(QtCore.Qt.Orientation.Horizontal) + self.main_layout.addWidget(self.splitter, stretch=1) + + # Graphics layout for image and points display + self.ui_graphics() + + self.ui_right_menu() + + # Style + self.setStyleSheet(""" + QTabWidget::pane { border: 0; } + QPushButton { + background-color: #444; + color: white; + padding: 6px 12px; + border: 1px solid #555; + border-radius: 4px; + } + QPushButton:checked { + background-color: #0078d7; + border: 1px solid #005bb5; + } + QGroupBox { + font-weight: bold; + border: 2px solid #555; + border-radius: 8px; + margin-top: 15px; + padding-top: 15px; + background-color: #3a3a3a; + color: white; + } + QGroupBox::title { + subcontrol-origin: margin; + subcontrol-position: top left; + left: 15px; + top: 4px; + padding: 2px 10px; + color: #e0e0e0; + background-color: #4a4a4a; + border: 1px solid #666; + border-radius: 4px; + font-size: 11px; + font-weight: bold; + } + """) + + # Connect selection change handler + self.button_group.idClicked.connect(self.method_selected) + + # Connect mouse click + self.pg_widget.scene().sigMouseClicked.connect(self.on_mouse_click) + + # Set the initial image + from ..video_reader import VideoReader + if isinstance(video, VideoReader): + self.frame = video.get_frame(0) + + self.image_item.setImage(self.frame.T) # axis 0 is x, while image axis 0 is y + + # Ensure method-specific widgets are visible on startup + self.method_selected(self.button_group.checkedId()) + # Don't auto-select any filter method - let user choose when needed + + # Set the initial mode + self.switch_mode("selection") # Default to selection mode + + # Start the GUI + self.show() + # Only call sys.exit if not in IPython + if not hasattr(sys, 'ps1'): # Not interactive + sys.exit(app.exec()) + else: + app.exec() # Don't raise SystemExit in IPythonys + + def eventFilter(self, source, event): + if event.type() == QtCore.QEvent.Type.KeyPress: + if event.key() == QtCore.Qt.Key.Key_Control: + self.ctrl_held = True + elif event.type() == QtCore.QEvent.Type.KeyRelease: + if event.key() == QtCore.Qt.Key.Key_Control: + self.ctrl_held = False + return super().eventFilter(source, event) + + def create_help_button(self, tooltip_text: str) -> QtWidgets.QToolButton: + """Create a small '?' help button with a tooltip.""" + button = QtWidgets.QToolButton() + button.setIcon(self.style().standardIcon(QtWidgets.QStyle.StandardPixmap.SP_MessageBoxQuestion)) + button.setToolTip(tooltip_text) + button.setCursor(QtCore.Qt.CursorShape.WhatsThisCursor) + button.setStyleSheet(""" + QToolButton { + border: none; + background: transparent; + padding: 0px; + } + QToolButton:hover { + color: #0078d7; + } + """) + button.setFixedSize(20, 20) + return button + + def ui_graphics(self): + # Image viewer + self.pg_widget = GraphicsLayoutWidget() + self.view = BrushViewBox(parent_gui=self, lockAspect=True, invertY=True) + self.pg_widget.addItem(self.view) + + + self.image_item = ImageItem() + self.polygon_line = pg.PlotDataItem(pen=pg.mkPen('y', width=2)) + self.polygon_points_scatter = ScatterPlotItem(pen=pg.mkPen(None), brush=pg.mkBrush(255, 255, 0, 200), size=6) + self.scatter = ScatterPlotItem(pen=pg.mkPen(None), brush=pg.mkBrush(255, 100, 100, 200), size=8) + self.roi_overlay = ImageItem() + + self.candidate_scatter = ScatterPlotItem( + pen=pg.mkPen(None), + brush=pg.mkBrush(0, 255, 0, 200), + size=6 + ) + self.direction_line = pg.PlotDataItem(pen=pg.mkPen('r', width=2)) + + self.view.addItem(self.image_item) + self.view.addItem(self.polygon_line) + self.view.addItem(self.polygon_points_scatter) + self.view.addItem(self.roi_overlay) # Add scatter for showing square points + self.view.addItem(self.scatter) # Add scatter for showing points + self.view.addItem(self.candidate_scatter) + self.view.addItem(self.direction_line) + + self.splitter.addWidget(self.pg_widget) + + def ui_right_menu(self): + # The right-side menu + self.method_widget = QtWidgets.QWidget() + self.stack = QtWidgets.QStackedLayout(self.method_widget) + + self.manual_widget = QtWidgets.QWidget() + self.manual_layout = QtWidgets.QVBoxLayout(self.manual_widget) + self.stack.addWidget(self.manual_widget) + + self.automatic_widget = QtWidgets.QWidget() + self.automatic_layout = QtWidgets.QVBoxLayout(self.automatic_widget) + self.stack.addWidget(self.automatic_widget) + + self.ui_manual_right_menu() # The manual right menu + + self.ui_auto_right_menu() # The automatic right menu + + # Set the layout and add to splitter + self.splitter.addWidget(self.method_widget) + self.splitter.setStretchFactor(0, 5) # Image area grows more + self.splitter.setStretchFactor(1, 0) # Menu fixed by content + + # Set initial width for right panel + self.method_widget.setMinimumWidth(150) + self.method_widget.setMaximumWidth(600) + self.splitter.setSizes([1000, 300]) # Initial left/right width + + self.automatic_layout.addStretch(1) + + def ui_manual_right_menu(self): + # Number of selected subsets + self.points_label = QtWidgets.QLabel("Selected subsets: 0") + font = self.points_label.font() + font.setPointSize(10) + font.setBold(True) + self.points_label.setFont(font) + + self.manual_layout.addWidget(self.points_label) + + # Method selection group + method_group = QtWidgets.QGroupBox("Selection Methods") + method_layout = QtWidgets.QVBoxLayout(method_group) + + # Method selection buttons + self.button_group = QtWidgets.QButtonGroup(self.method_widget) + self.button_group.setExclusive(True) + + self.method_buttons = {} + method_names = [ + "Grid", + "Manual", + "Along the line", + "Brush", + "Remove point", + ] + for i, name in enumerate(method_names): + button = QtWidgets.QPushButton(name) + button.setCheckable(True) + if i == 0: + button.setChecked(True) # Default selection + self.button_group.addButton(button, i) + method_layout.addWidget(button) + self.method_buttons[name] = button + + self.manual_layout.addWidget(method_group) + + # Subset configuration group + config_group = QtWidgets.QGroupBox("Subset Configuration") + config_layout = QtWidgets.QVBoxLayout(config_group) + + # Subset size input + self.subset_size_layout = QtWidgets.QHBoxLayout() + self.subset_size_layout.addWidget(QtWidgets.QLabel("Subset size:")) + + self.subset_size_spinbox = QtWidgets.QSpinBox() + self.subset_size_spinbox.setRange(1, 1000) + self.subset_size_spinbox.setValue(11) + self.subset_size_spinbox.setAlignment(QtCore.Qt.AlignmentFlag.AlignRight) + self.subset_size_spinbox.setSingleStep(2) + self.subset_size_spinbox.setMinimum(1) + self.subset_size_spinbox.setMaximum(999) + self.subset_size_spinbox.setWrapping(False) + self.subset_size_spinbox.setSuffix("px") + self.subset_size_spinbox.setFixedWidth(80) + self.subset_size_spinbox.valueChanged.connect(self.update_subset_size_from_spinbox) + self.subset_size_layout.addWidget(self.subset_size_spinbox) + + self.subset_size_layout.addStretch() # Push everything to the left + config_layout.addLayout(self.subset_size_layout) + + self.subset_size_slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal) + self.subset_size_slider.setRange(1, 100) + self.subset_size_slider.setValue(11) + self.subset_size_slider.setSingleStep(1) + self.subset_size_slider.valueChanged.connect(self.update_subset_size_from_slider) + config_layout.addWidget(self.subset_size_slider) + + # Show ROI rectangles + self.show_roi_checkbox = QtWidgets.QCheckBox("Show subsets") + self.show_roi_checkbox.setChecked(True) + self.show_roi_checkbox.stateChanged.connect(self.update_selected_points) + config_layout.addWidget(self.show_roi_checkbox) + + # Clear button + self.clear_button = QtWidgets.QPushButton("Clear selections") + self.clear_button.clicked.connect(self.clear_selection) + config_layout.addWidget(self.clear_button) + + self.manual_layout.addWidget(config_group) + + # Method-specific controls group + method_controls_group = QtWidgets.QGroupBox("Method-Specific Controls") + method_controls_layout = QtWidgets.QVBoxLayout(method_controls_group) + + # Distance between subsets (only visible for Grid and Along the line) + self.distance_layout = QtWidgets.QHBoxLayout() + self.distance_layout.addWidget(QtWidgets.QLabel("Distance between subsets:")) + + self.distance_spinbox = QtWidgets.QSpinBox() + self.distance_spinbox.setRange(-50, 50) + self.distance_spinbox.setSingleStep(1) + self.distance_spinbox.setValue(0) + self.distance_spinbox.setSuffix("px") + self.distance_spinbox.setFixedWidth(80) + self.distance_spinbox.valueChanged.connect(self.update_distance_from_spinbox) + self.distance_layout.addWidget(self.distance_spinbox) + + self.distance_layout.addStretch() # Push everything to the left + + # Create a widget to hold the distance controls + self.distance_widget = QtWidgets.QWidget() + self.distance_widget.setLayout(self.distance_layout) + self.distance_widget.setVisible(False) # Hidden by default + method_controls_layout.addWidget(self.distance_widget) + + self.distance_slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal) + self.distance_slider.setRange(-50, 50) + self.distance_slider.setSingleStep(1) + self.distance_slider.setValue(0) + self.distance_slider.setVisible(False) + self.distance_slider.valueChanged.connect(self.update_distance_from_slider) + method_controls_layout.addWidget(self.distance_slider) + + # Start new line (only visible in "Along the line" mode) + self.start_new_line_button = QtWidgets.QPushButton("Start new line") + self.start_new_line_button.clicked.connect(self.start_new_line) + self.start_new_line_button.setVisible(False) # Hidden by default + method_controls_layout.addWidget(self.start_new_line_button) + + # Brush mode + self.brush_radius_label = QtWidgets.QLabel(f"Brush radius (px): {self._paint_radius}") + self.brush_radius_label.setVisible(False) # shown only for Brush mode + method_controls_layout.addWidget(self.brush_radius_label) + + self.brush_radius_slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal) + self.brush_radius_slider.setRange(1, 50) + self.brush_radius_slider.setSingleStep(1) + self.brush_radius_slider.setValue(self._paint_radius) + self.brush_radius_slider.setVisible(False) # shown only for Brush mode + self.brush_radius_slider.valueChanged.connect(lambda val: self.brush_radius_label.setText(f"Brush radius (px): {val}")) + method_controls_layout.addWidget(self.brush_radius_slider) + + self.brush_deselect_button = QtWidgets.QPushButton("Deselect painted area") + self.brush_deselect_button.setCheckable(True) + self.brush_deselect_button.setVisible(False) # shown only for Brush mode + self.brush_deselect_button.clicked.connect(self.activate_brush_deselect) + method_controls_layout.addWidget(self.brush_deselect_button) + + # Polygon manager (visible only for "Along the line") + self.polygon_list = QtWidgets.QListWidget() + self.polygon_list.setVisible(False) + self.polygon_list.currentRowChanged.connect(self.on_polygon_selected) + method_controls_layout.addWidget(self.polygon_list) + + self.delete_polygon_button = QtWidgets.QPushButton("Delete selected polygon") + self.delete_polygon_button.clicked.connect(self.delete_selected_polygon) + self.delete_polygon_button.setVisible(False) + method_controls_layout.addWidget(self.delete_polygon_button) + + # Grid polygon manager + self.grid_list = QtWidgets.QListWidget() + self.grid_list.setVisible(False) + self.grid_list.currentRowChanged.connect(self.on_grid_selected) + method_controls_layout.addWidget(self.grid_list) + + self.delete_grid_button = QtWidgets.QPushButton("Delete selected grid") + self.delete_grid_button.clicked.connect(self.delete_selected_grid) + self.delete_grid_button.setVisible(False) + method_controls_layout.addWidget(self.delete_grid_button) + + self.manual_layout.addWidget(method_controls_group) + + self.manual_layout.addStretch(1) + + def ui_auto_right_menu(self): + self.candidate_count_label = QtWidgets.QLabel("N candidate points: 0") + font = self.candidate_count_label.font() + font.setPointSize(10) + font.setBold(True) + self.candidate_count_label.setFont(font) + + self.automatic_layout.addWidget(self.candidate_count_label) + + # Filter method selection group + filter_method_group = QtWidgets.QGroupBox("Filter Methods") + filter_method_layout = QtWidgets.QVBoxLayout(filter_method_group) + + self.auto_method_group = QtWidgets.QButtonGroup(self.automatic_widget) + self.auto_method_group.setExclusive(True) + + self.auto_method_buttons = {} + method_names = [ + "Shi-Tomasi", + "Gradient in direction", + ] + for i, name in enumerate(method_names): + button = QtWidgets.QPushButton(name) + button.setCheckable(True) + # Don't auto-select any method - let user choose + self.auto_method_group.addButton(button, i) + filter_method_layout.addWidget(button) + self.auto_method_buttons[name] = button + + self.auto_method_group.idClicked.connect(self.auto_method_selected) + + self.automatic_layout.addWidget(filter_method_group) + + # Display options group + display_options_group = QtWidgets.QGroupBox("Display Options") + display_options_layout = QtWidgets.QVBoxLayout(display_options_group) + + # Checkbox to show/hide scatter and ROI overlay + self.show_points_checkbox = QtWidgets.QCheckBox("Show points/ROIs") + self.show_points_checkbox.setChecked(False) + def toggle_points_and_roi(state): + self.roi_overlay.setVisible(state) + self.scatter.setVisible(state) + self.show_points_checkbox.stateChanged.connect(toggle_points_and_roi) + display_options_layout.addWidget(self.show_points_checkbox) + + # Clear the candidates button + self.clear_candidates_button = QtWidgets.QPushButton("Clear candidates") + self.clear_candidates_button.clicked.connect(self.clear_candidates) + display_options_layout.addWidget(self.clear_candidates_button) + + self.automatic_layout.addWidget(display_options_group) + + # Method settings group + method_settings_group = QtWidgets.QGroupBox("Method Settings") + method_settings_layout = QtWidgets.QVBoxLayout(method_settings_group) + + # Shi-Tomasi method settings + self.shi_tomasi_threshold = 10 # Default threshold value + self.threshold_label = QtWidgets.QLabel(f"Threshold: {self.shi_tomasi_threshold}") + self.threshold_label.setVisible(False) + method_settings_layout.addWidget(self.threshold_label) + + self.threshold_slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal) + self.threshold_slider.setRange(1, 100) + self.threshold_slider.setSingleStep(1) + self.threshold_slider.setValue(self.shi_tomasi_threshold) + self.threshold_slider.setVisible(False) + method_settings_layout.addWidget(self.threshold_slider) + + def update_label_and_recompute(val): + self.threshold_label.setText(f"Threshold: {str(val)}") + self.update_threshold_and_show_shi_tomsi() # Placeholder method + self.threshold_slider.valueChanged.connect(update_label_and_recompute) + + # Gradient in a specified direction settings + self.direction_button = QtWidgets.QPushButton("Set direction on image") + self.direction_button.setVisible(False) + self.direction_button.setCheckable(True) + self.direction_button.clicked.connect(self.set_gradient_direction_mode) + method_settings_layout.addWidget(self.direction_button) + + # Preset direction buttons + preset_layout = QtWidgets.QHBoxLayout() + + self.x_direction_button = QtWidgets.QPushButton("X Direction") + self.x_direction_button.setVisible(False) + self.x_direction_button.clicked.connect(self.set_x_direction_preset) + preset_layout.addWidget(self.x_direction_button) + + self.y_direction_button = QtWidgets.QPushButton("Y Direction") + self.y_direction_button.setVisible(False) + self.y_direction_button.clicked.connect(self.set_y_direction_preset) + preset_layout.addWidget(self.y_direction_button) + + # Create a widget to hold the preset buttons + self.preset_buttons_widget = QtWidgets.QWidget() + self.preset_buttons_widget.setLayout(preset_layout) + self.preset_buttons_widget.setVisible(False) + method_settings_layout.addWidget(self.preset_buttons_widget) + + self.direction_threshold = 10 + self.gradient_thresh_label = QtWidgets.QLabel(f"Threshold (grad): {self.direction_threshold}") + self.gradient_thresh_label.setVisible(False) + method_settings_layout.addWidget(self.gradient_thresh_label) + + self.gradient_thresh_slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal) + self.gradient_thresh_slider.setRange(1, 100) + self.gradient_thresh_slider.setSingleStep(1) + self.gradient_thresh_slider.setValue(self.direction_threshold) + self.gradient_thresh_slider.setVisible(False) + method_settings_layout.addWidget(self.gradient_thresh_slider) + + def update_direction_thresh(val): + self.gradient_thresh_label.setText(f"Threshold (grad): {val}") + self.update_threshold_and_show_gradient_direction() + self.gradient_thresh_slider.valueChanged.connect(update_direction_thresh) + + self.automatic_layout.addWidget(method_settings_group) + + self.automatic_layout.addStretch(1) + + def auto_method_selected(self, id: int): + # Check if any button is actually checked + if self.auto_method_group.checkedButton() is None: + return + + method_name = list(self.auto_method_buttons.keys())[id] + # print(f"Selected automatic method: {method_name}") + # Here you can switch method behavior, show/hide widgets, etc. + is_shi_tomasi = method_name == "Shi-Tomasi" + is_gradient_dir = method_name == "Gradient in direction" + + # Reset gradient direction selection when switching away from gradient method + if not is_gradient_dir and hasattr(self, 'direction_button') and self.direction_button.isChecked(): + self.direction_button.setChecked(False) + self.set_gradient_direction_mode() + + # Hide direction line when not in gradient direction mode + if not is_gradient_dir and hasattr(self, 'direction_line'): + self.direction_line.clear() + + self.threshold_label.setVisible(is_shi_tomasi) + self.threshold_slider.setVisible(is_shi_tomasi) + + if is_shi_tomasi: + self.compute_candidate_points_shi_tomasi() + + self.direction_button.setVisible(is_gradient_dir) + self.preset_buttons_widget.setVisible(is_gradient_dir) + self.gradient_thresh_label.setVisible(is_gradient_dir) + self.gradient_thresh_slider.setVisible(is_gradient_dir) + self.preset_buttons_widget.setVisible(is_gradient_dir) + self.y_direction_button.setVisible(is_gradient_dir) + self.x_direction_button.setVisible(is_gradient_dir) + + if is_gradient_dir and self.gradient_direction is not None: + self.compute_candidate_points_gradient_direction() + # Show the direction line if we have gradient direction points + if hasattr(self, 'gradient_direction_points') and len(self.gradient_direction_points) == 2: + self.update_direction_line() + + if is_shi_tomasi: + self.show_instruction("Use the threshold slider to filter points.") + elif is_gradient_dir: + self.show_instruction("Click 'Set direction on image' button and drag to define the gradient direction.") + + def show_instruction(self, message: str): + self.statusBar.showMessage(message) + + def method_selected(self, id: int): + method_name = list(self.method_buttons.keys())[id] + # print(f"Selected method: {method_name}") + is_along = method_name == "Along the line" + is_grid = method_name == "Grid" + is_brush = method_name == "Brush" + + show_spacing = is_along or is_grid or is_brush + + self.start_new_line_button.setVisible(is_along or is_grid) + self.polygon_list.setVisible(is_along) + self.delete_polygon_button.setVisible(is_along) + self.grid_list.setVisible(is_grid) + self.delete_grid_button.setVisible(is_grid) + + self.distance_widget.setVisible(show_spacing) + self.distance_slider.setVisible(show_spacing) + + self.brush_deselect_button.setVisible(is_brush) + self.brush_radius_label.setVisible(is_brush) + self.brush_radius_slider.setVisible(is_brush) + + # Show context-sensitive instructions + if is_brush: + self.show_instruction("Hold Ctrl and drag to paint selection area. Use distance slider to control subset spacing.") + elif is_along: + self.show_instruction("Click to add points along the line. Click 'Start new line' to begin a new one.") + elif is_grid: + self.show_instruction("Click to define grid corners. Click 'Start new line' to begin a new grid.") + elif method_name == "Manual": + self.show_instruction("Click to add points manually.") + elif method_name == "Remove point": + self.show_instruction("Click on a point to remove it.") + else: + self.show_instruction("Ready.") + + def switch_mode(self, mode: str): + self.mode = mode + if mode == "selection": + self.selection_mode_button.setChecked(True) + self.filter_mode_button.setChecked(False) + self.stack.setCurrentWidget(self.manual_widget) + + # Reset gradient direction selection when leaving filter mode + if hasattr(self, 'direction_button') and self.direction_button.isChecked(): + self.direction_button.setChecked(False) + self.set_gradient_direction_mode() + + # Hide direction line when leaving filter mode + if hasattr(self, 'direction_line'): + self.direction_line.clear() + + self.roi_overlay.setVisible(True) + self.scatter.setVisible(True) + self.show_instruction("Selection mode: choose a method on the left.") + + elif mode == "filter": + self.selection_mode_button.setChecked(False) + self.filter_mode_button.setChecked(True) + self.stack.setCurrentWidget(self.automatic_widget) + + # Don't automatically compute anything - let user select method first + self.show_points_checkbox.setChecked(False) + self.roi_overlay.setVisible(False) + self.scatter.setVisible(False) + self.show_instruction("Filter mode: choose a filter method and adjust settings.") + + def on_mouse_click(self, event): + if self.mode == "filter": + return + + if self.method_buttons["Manual"].isChecked(): + self.handle_manual_selection(event) + elif self.method_buttons["Along the line"].isChecked(): + self.handle_polygon_drawing(event) + elif self.method_buttons["Grid"].isChecked(): + self.handle_grid_drawing(event) + elif self.method_buttons["Remove point"].isChecked(): + self.handle_remove_point(event) + elif self.method_buttons["Brush"].isChecked(): + self.handle_brush_start(event) + + def update_selected_points(self): + polygon_points = [pt for poly in self.drawing_polygons for pt in poly['roi_points']] + grid_points = [pt for g in self.grid_polygons for pt in g['roi_points']] + self.selected_points = self.manual_points + polygon_points + grid_points + self.brush_points + + if not self.selected_points: + self.scatter.clear() + self.roi_overlay.clear() + return + + subset_size = self.subset_size_spinbox.value() + half = subset_size // 2 + + # selected_points = np.round(np.array(self.selected_points) - 0.5) + selected_points = np.array(self.selected_points) + + # --- Rectangles --- + if self.show_roi_checkbox.isChecked(): + h, w = self.image_item.image.shape[:2] + overlay = np.zeros((h, w, 4), dtype=np.uint8) # RGBA + + for y, x in selected_points: + x0 = int(round(x - half)) + y0 = int(round(y - half)) + x1 = int(round(x + half+1)) + y1 = int(round(y + half+1)) + + # Ensure bounds + if x0 < 0 or y0 < 0 or x1 >= w or y1 >= h: + continue + + # Fill interior (semi-transparent green) + overlay[y0:y1, x0:x1, 1] = 180 # green + overlay[y0:y1, x0:x1, 3] = 40 # alpha + + # Outline (more opaque green) + overlay[y0, x0:x1, 1] = 255 # top + overlay[y1 - 1, x0:x1, 1] = 255 # bottom + overlay[y0:y1, x0, 1] = 255 # left + overlay[y0:y1, x1 - 1, 1] = 255 # right + + overlay[y0, x0:x1, 3] = 150 + overlay[y1 - 1, x0:x1, 3] = 150 + overlay[y0:y1, x0, 3] = 150 + overlay[y0:y1, x1 - 1, 3] = 150 + + self.roi_overlay.setImage(overlay, autoLevels=False) + self.roi_overlay.setZValue(1) + else: + self.roi_overlay.clear() + + # --- Center Dots --- + self.scatter.setData( + pos=selected_points + 0.5, + symbol='o', + size=6, + brush=pg.mkBrush(255, 100, 100, 200), + pen=pg.mkPen(None) + ) + self.points_label.setText(f"Selected subsets: {len(self.selected_points)}") + + def update_distance_from_slider(self, value): + """Update distance spinbox from slider value and recompute ROI points.""" + # Update spinbox without triggering its signal + self.distance_spinbox.blockSignals(True) + self.distance_spinbox.setValue(value) + self.distance_spinbox.blockSignals(False) + + # Recompute ROI points + self.recompute_roi_points() + + def update_distance_from_spinbox(self, value): + """Update distance slider from spinbox value and recompute ROI points.""" + # Update slider without triggering its signal + self.distance_slider.blockSignals(True) + self.distance_slider.setValue(value) + self.distance_slider.blockSignals(False) + + # Recompute ROI points + self.recompute_roi_points() + + def update_subset_size_from_slider(self, value): + """Update subset size spinbox from slider value and recompute ROI points.""" + # Update spinbox without triggering its signal + self.subset_size_spinbox.blockSignals(True) + self.subset_size_spinbox.setValue(value) + self.subset_size_spinbox.blockSignals(False) + + # Recompute ROI points and update display + self.recompute_roi_points() + + def update_subset_size_from_spinbox(self, value): + """Update subset size slider from spinbox value and recompute ROI points.""" + # Update slider, clamping to its range + slider_value = min(100, max(1, value)) + self.subset_size_slider.blockSignals(True) + self.subset_size_slider.setValue(slider_value) + self.subset_size_slider.blockSignals(False) + + # Recompute ROI points and update display + self.recompute_roi_points() + + def recompute_roi_points(self): + subset_size = self.subset_size_spinbox.value() + spacing = self.distance_spinbox.value() + + # Update all "along the line" polygons + for poly in self.drawing_polygons: + if len(poly['points']) >= 2: + poly['roi_points'] = points_along_polygon(poly['points'], subset_size, spacing) + + # Update all "grid" polygons + for grid in self.grid_polygons: + if len(grid['points']) >= 3: + grid['roi_points'] = rois_inside_polygon(grid['points'], subset_size, spacing) + + # Update all brush masks + self.brush_points = [] + for mask in self.brush_masks: + self.brush_points.extend(rois_inside_mask(mask, subset_size, spacing)) + + self.update_selected_points() + + def start_new_line(self): + # print("Starting a new line...") + + if self.method_buttons["Along the line"].isChecked(): + self.drawing_polygons.append({'points': [], 'roi_points': []}) + self.active_polygon_index = len(self.drawing_polygons) - 1 + self.polygon_list.addItem(f"Polygon {self.active_polygon_index + 1}") + self.polygon_list.setCurrentRow(self.active_polygon_index) + self.update_polygon_display() + + elif self.method_buttons["Grid"].isChecked(): + self.grid_polygons.append({'points': [], 'roi_points': []}) + self.active_grid_index = len(self.grid_polygons) - 1 + self.grid_list.addItem(f"Grid {self.active_grid_index + 1}") + self.grid_list.setCurrentRow(self.active_grid_index) + self.update_grid_display() + + self.update_selected_points() + + def clear_selection(self): + # print("Clearing selections...") + + # Clear manual points + self.manual_points = [] + + # Clear brush data + self.brush_masks = [] + self.brush_points = [] + + # Clear line-based polygons + self.drawing_polygons = [{'points': [], 'roi_points': []}] + self.polygon_list.clear() + self.polygon_list.addItem("Polygon 1") + self.polygon_list.setCurrentRow(0) + self.active_polygon_index = 0 + self.polygon_line.clear() + self.polygon_points_scatter.clear() + + # Clear grid-based polygons + self.grid_polygons = [{'points': [], 'roi_points': []}] + self.grid_list.clear() + self.grid_list.addItem("Grid 1") + self.grid_list.setCurrentRow(0) + self.active_grid_index = 0 + + if hasattr(self, 'grid_line'): + self.grid_line.clear() + if hasattr(self, 'grid_points_scatter'): + self.grid_points_scatter.clear() + + # Clear selected points and visual overlays + self.selected_points = [] + + if hasattr(self, 'scatter'): + self.scatter.clear() + if hasattr(self, 'roi_overlay'): + self.roi_overlay.clear() + + # Clear candidate points from automatic filtering + self.clear_candidates() + + self.points_label.setText("Selected subsets: 0") + + # Reset gradient direction selection and clear direction line + if hasattr(self, 'direction_button') and self.direction_button.isChecked(): + self.direction_button.setChecked(False) + self.set_gradient_direction_mode() + self.direction_line.clear() + + self.update_selected_points() # Refresh display + + def set_image(self, img: np.ndarray): + """Display image in the manual tab.""" + self.image_item.setImage(img) + + def get_points(self): + """Get all selected points from manual and polygons.""" + filtered_points = self.get_filtered_points() + if filtered_points.size > 0: + return filtered_points + else: + return self.get_selected_points() + + @property + def points(self): + return self.get_points() + + def get_filtered_points(self): + """Get candidate points from filtering.""" + return np.array(self.candidate_points)[:, ::-1] if hasattr(self, 'candidate_points') else [] + + def get_selected_points(self): + """Get all selected points from manual, polygons and grid.""" + return np.array(self.selected_points)[:, ::-1] if self.selected_points else [] + + # Grid selection + def handle_grid_drawing(self, event): + pos = event.scenePos() + if self.view.sceneBoundingRect().contains(pos): + mouse_point = self.view.mapSceneToView(pos) + x, y = mouse_point.x(), mouse_point.y() + + # Add first grid polygon to the list if not yet shown + if self.grid_list.count() == 0: + self.grid_list.addItem("Grid 1") + self.grid_list.setCurrentRow(0) + + grid = self.grid_polygons[self.active_grid_index] + grid['points'].append((x, y)) + + # Compute ROI points only if closed polygon + if len(grid['points']) >= 3: + subset_size = self.subset_size_spinbox.value() + spacing = self.distance_spinbox.value() + grid['roi_points'] = rois_inside_polygon(grid['points'], subset_size, spacing) + + self.update_grid_display() + self.update_selected_points() + + def on_grid_selected(self, index): + if 0 <= index < len(self.grid_polygons): + self.active_grid_index = index + + def delete_selected_grid(self): + row = self.grid_list.currentRow() + if row >= 0 and len(self.grid_polygons) > 1: + del self.grid_polygons[row] + self.grid_list.takeItem(row) + self.active_grid_index = max(0, row - 1) + self.grid_list.setCurrentRow(self.active_grid_index) + self.update_grid_display() + self.update_selected_points() + + def update_grid_display(self): + # Combine all points from all grid polygons for scatter + all_pts = [pt for poly in self.grid_polygons for pt in poly['points']] + + # Create or update scatter plot for grid polygon vertices + if not hasattr(self, 'grid_points_scatter'): + self.grid_points_scatter = ScatterPlotItem( + pen=pg.mkPen(None), + brush=pg.mkBrush(255, 200, 0, 200), + size=6 + ) + self.view.addItem(self.grid_points_scatter) + self.grid_points_scatter.setData(pos=all_pts) + + # Combine all polygon outlines with np.nan-separated segments + xs, ys = [], [] + for poly in self.grid_polygons: + path = poly['points'] + if len(path) >= 2: + xs.extend([p[0] for p in path] + [path[0][0], np.nan]) # Close polygon + ys.extend([p[1] for p in path] + [path[0][1], np.nan]) + elif len(path) == 1: + xs.extend([path[0][0], path[0][0], np.nan]) + ys.extend([path[0][1], path[0][1], np.nan]) + + # Create or update line plot for polygon outlines + if not hasattr(self, 'grid_line'): + self.grid_line = pg.PlotDataItem( + pen=pg.mkPen('c', width=2) # Cyan line + ) + self.view.addItem(self.grid_line) + self.grid_line.setData(xs, ys) + + # Manual selection + def handle_manual_selection(self, event): + """Handle manual selection of points.""" + pos = event.scenePos() + if self.view.sceneBoundingRect().contains(pos): + mouse_point = self.view.mapSceneToView(pos) + x, y = mouse_point.x(), mouse_point.y() + x_int, y_int = round(x-0.5), round(y-0.5) + self.manual_points.append((x_int, y_int)) + self.update_selected_points() + + # Along the line selection + def handle_polygon_drawing(self, event): + pos = event.scenePos() + if self.view.sceneBoundingRect().contains(pos): + mouse_point = self.view.mapSceneToView(pos) + x, y = mouse_point.x(), mouse_point.y() + + # Add first polygon to the list if not yet shown + if self.polygon_list.count() == 0: + self.polygon_list.addItem("Polygon 1") + self.polygon_list.setCurrentRow(0) + + poly = self.drawing_polygons[self.active_polygon_index] + poly['points'].append((x, y)) + + # Update ROI points only for this polygon + if len(poly['points']) >= 2: + subset_size = self.subset_size_spinbox.value() + spacing = self.distance_spinbox.value() + poly['roi_points'] = points_along_polygon(poly['points'], subset_size, spacing) + + self.update_polygon_display() + self.update_selected_points() + + def delete_selected_polygon(self): + row = self.polygon_list.currentRow() + if row >= 0 and len(self.drawing_polygons) > 1: + del self.drawing_polygons[row] + self.polygon_list.takeItem(row) + self.active_polygon_index = max(0, row - 1) + self.polygon_list.setCurrentRow(self.active_polygon_index) + self.update_polygon_display() + self.update_selected_points() + + def update_polygon_display(self): + all_pts = [pt for poly in self.drawing_polygons for pt in poly['points']] + self.polygon_points_scatter.setData(pos=all_pts) + + xs, ys = [], [] + for poly in self.drawing_polygons: + path = poly['points'] + if len(path) >= 2: + xs.extend([p[0] for p in path] + [np.nan]) + ys.extend([p[1] for p in path] + [np.nan]) + elif len(path) == 1: + xs.extend([path[0][0], path[0][0], np.nan]) + ys.extend([path[0][1], path[0][1], np.nan]) + + self.polygon_line.setData(xs, ys) + + def on_polygon_selected(self, index): + if 0 <= index < len(self.drawing_polygons): + self.active_polygon_index = index + + # Remove point selection + def handle_remove_point(self, event): + pos = event.scenePos() + if self.view.sceneBoundingRect().contains(pos): + mouse_point = self.view.mapSceneToView(pos) + x, y = mouse_point.x(), mouse_point.y() + + # Find nearest point + if not self.selected_points: + return + + pts = np.array(self.selected_points) + distances = np.linalg.norm(pts - np.array([x, y]), axis=1) + idx = np.argmin(distances) + closest = tuple(pts[idx]) + + # Remove from manual if present + if closest in self.manual_points: + self.manual_points.remove(closest) + + # Remove from polygons + for poly in self.drawing_polygons: + if closest in poly['roi_points']: + poly['roi_points'].remove(closest) + + # Remove from grid + for grid in self.grid_polygons: + if closest in grid['roi_points']: + grid['roi_points'].remove(closest) + + # Remove from brush points + if closest in self.brush_points: + self.brush_points.remove(closest) + + self.update_selected_points() + + # Automatic filtering + # Shi-Tomasi method + def compute_candidate_points_shi_tomasi(self): + """Compute good feature points using structure tensor analysis (Shi–Tomasi style).""" + from scipy.ndimage import sobel + + subset_size = self.subset_size_spinbox.value() + roi_size = subset_size // 2 + + img = self.image_item.image.astype(np.float32) + candidates = [] + + # All selected points (not just manual) + for row, col in self.selected_points: + y, x = int(round(row)), int(round(col)) + + if (y - roi_size < 0 or y + roi_size + 1 > img.shape[0] or + x - roi_size < 0 or x + roi_size + 1 > img.shape[1]): + continue + + roi = img[y - roi_size: y + roi_size + 1, + x - roi_size: x + roi_size + 1] + + # Compute gradients + gx = sobel(roi, axis=1) + gy = sobel(roi, axis=0) + + Gx2 = np.sum(gx ** 2) + Gy2 = np.sum(gy ** 2) + GxGy = np.sum(gx * gy) + + matrix = np.array([[Gx2, GxGy], + [GxGy, Gy2]]) + + eigvals = np.linalg.eigvalsh(matrix) # sorted ascending + min_eig = eigvals[0] + + candidates.append((x + 0.0, y + 0.0, min_eig)) + + if not candidates: + self.candidate_points = [] + self.update_candidate_display() + return + + # Threshold by normalized eigenvalue + eigvals = np.array([v[2] for v in candidates]) + self.max_eig_shi_tomasi = np.max(eigvals) + + self.candidates_shi_tomasi = candidates + + self.update_threshold_and_show_shi_tomsi() + + def update_threshold_and_show_shi_tomsi(self): + threshold_ratio = self.threshold_slider.value() / 1000.0 + + eig_threshold = self.max_eig_shi_tomasi * threshold_ratio + + self.candidate_points = [(round(y), round(x)) for (x, y, e) in self.candidates_shi_tomasi if e > eig_threshold] + self.update_candidate_display() + self.update_candidate_points_count() + + def update_candidate_points_count(self): + """Update the displayed count of candidate points.""" + if self.candidate_points: + count_text = f"N candidate points: {len(self.candidate_points)}" + else: + count_text = "N candidate points: 0" + + self.candidate_count_label.setText(count_text) + + def update_candidate_display(self): + """Show candidate points as scatter dots on the image.""" + if not hasattr(self, 'candidate_scatter'): + self.candidate_scatter = ScatterPlotItem( + pen=pg.mkPen(None), + brush=pg.mkBrush(0, 255, 0, 150), # green with transparency + size=6, + symbol='o' + ) + self.view.addItem(self.candidate_scatter) + + if self.candidate_points: + self.candidate_scatter.setData(pos=np.array(self.candidate_points) + 0.5) + else: + self.candidate_scatter.clear() + + def clear_candidates(self): + """Clear candidate points.""" + # print("Clearing candidate points...") + self.candidate_points = [] + self.update_candidate_points_count() + if hasattr(self, 'candidate_scatter'): + self.candidate_scatter.clear() + + # Reset gradient direction selection when clearing candidates + if hasattr(self, 'direction_button') and self.direction_button.isChecked(): + self.direction_button.setChecked(False) + self.set_gradient_direction_mode() + + self.update_selected_points() # Update main display to remove candidates + + # Gradient in a specified direction + def set_gradient_direction_mode(self): + """Toggle gradient direction selection mode.""" + self.setting_direction = self.direction_button.isChecked() + + if self.setting_direction: + self.direction_button.setText("Cancel Direction") + self.direction_button.setStyleSheet("background-color: #d73a00;") + self.gradient_direction_points = [] + # Clear the direction line only when starting new selection + self.direction_line.clear() + self.show_instruction("Click and drag to set the gradient direction.") + else: + self.direction_button.setText("Set direction on image") + self.direction_button.setStyleSheet("") + # Don't clear the direction line when finishing selection - keep it visible + self.show_instruction("Filter mode: choose a filter method and adjust settings.") + + def compute_direction_vector(self): + p1, p2 = self.gradient_direction_points + dx, dy = p2[0] - p1[0], p2[1] - p1[1] + norm = np.sqrt(dx**2 + dy**2) + if norm == 0: + self.gradient_direction = None + else: + self.gradient_direction = (dx / norm, dy / norm) + + def compute_candidate_points_gradient_direction(self): + from scipy.ndimage import sobel + + if self.gradient_direction is None: + return + + dy, dx = self.gradient_direction + subset_size = self.subset_size_spinbox.value() + roi_size = subset_size // 2 + + img = self.image_item.image.astype(np.float32) + candidates = [] + + for row, col in self.selected_points: + y, x = int(round(row)), int(round(col)) + + if (y - roi_size < 0 or y + roi_size + 1 > img.shape[0] or + x - roi_size < 0 or x + roi_size + 1 > img.shape[1]): + continue + + roi = img[y - roi_size: y + roi_size + 1, + x - roi_size: x + roi_size + 1] + + gx = sobel(roi, axis=1) + gy = sobel(roi, axis=0) + + gdir = np.abs(gx * dx) + np.abs(gy * dy) + strength = np.sum(np.abs(gdir)) + + candidates.append((x + 0.0, y + 0.0, strength)) + + if not candidates: + self.candidate_points = [] + self.update_candidate_display() + return + + values = np.array([v[2] for v in candidates]) + self.max_grad_dir = np.max(values) + self.candidates_grad_dir = candidates + self.update_threshold_and_show_gradient_direction() + + def update_threshold_and_show_gradient_direction(self): + threshold_ratio = self.gradient_thresh_slider.value() / 100.0 + threshold = self.max_grad_dir * threshold_ratio + + self.candidate_points = [ + (round(y), round(x)) + for (x, y, v) in self.candidates_grad_dir + if v > threshold + ] + self.update_candidate_display() + self.update_candidate_points_count() + + def update_direction_line(self): + if len(self.gradient_direction_points) == 2: + xs = [p[0] for p in self.gradient_direction_points] + ys = [p[1] for p in self.gradient_direction_points] + self.direction_line.setData(xs, ys) + else: + self.direction_line.clear() + + # Brush + def handle_brush_start(self, ev): + if self.image_item.image is None: + return + h, w = self.image_item.image.shape[:2] + self._paint_mask = np.zeros((h, w), dtype=bool) + self.handle_brush_move(ev) + + def handle_brush_move(self, ev): + if self._paint_mask is None: + return + + pos = ev.pos() + if self.view.sceneBoundingRect().contains(pos): + mouse_point = self.view.mapSceneToView(pos) + y, x = int(round(mouse_point.x())), int(round(mouse_point.y())) + r = self.brush_radius_slider.value() + + h, w = self._paint_mask.shape + yy, xx = np.ogrid[max(0, y - r):min(h, y + r + 1), + max(0, x - r):min(w, x + r + 1)] + mask = (yy - y) ** 2 + (xx - x) ** 2 <= r ** 2 + self._paint_mask[max(0, y - r):min(h, y + r + 1), + max(0, x - r):min(w, x + r + 1)][mask] = True + + self.update_brush_overlay() + + def handle_brush_end(self, ev): + if self._paint_mask is None: + return + + subset_size = self.subset_size_spinbox.value() + spacing = self.distance_spinbox.value() + + # Generate (row, col) points inside the painted mask + brush_rois = rois_inside_mask(self._paint_mask, subset_size, spacing) + + if self.brush_deselect_mode: + def point_inside_mask(pt, mask): + y, x = int(round(pt[0])), int(round(pt[1])) + h, w = mask.shape + return 0 <= y < h and 0 <= x < w and mask[y, x] + + # Remove from manual points + self.manual_points = [ + pt for pt in self.manual_points + if not point_inside_mask(pt, self._paint_mask) + ] + + # Remove from polygons + for poly in self.drawing_polygons: + poly['roi_points'] = [pt for pt in poly['roi_points'] if not point_inside_mask(pt, self._paint_mask)] + + # Remove from grid polygons + for grid in self.grid_polygons: + grid['roi_points'] = [pt for pt in grid['roi_points'] if not point_inside_mask(pt, self._paint_mask)] + + # Remove from brush points + self.brush_points = [ + pt for pt in self.brush_points + if not point_inside_mask(pt, self._paint_mask) + ] + + # Remove brush masks that are covered by the current mask + self.brush_masks = [mask for mask in self.brush_masks + if not np.any(mask & self._paint_mask)] + + self.brush_deselect_mode = False + self.brush_deselect_button.setChecked(False) + + else: + # Store the mask for future recomputation + self.brush_masks.append(self._paint_mask.copy()) + # Add points to brush_points + self.brush_points.extend(brush_rois) + + self._paint_mask = None + self.update_selected_points() + self.update_brush_overlay() + + def update_brush_overlay(self): + if not hasattr(self, 'brush_overlay'): + self.brush_overlay = ImageItem() + self.view.addItem(self.brush_overlay) + + if self._paint_mask is not None: + rgba = np.zeros((*self._paint_mask.shape, 4), dtype=np.uint8) + if self.brush_deselect_mode: + rgba[self._paint_mask] = [255, 0, 0, 80] # Red with transparency + else: + rgba[self._paint_mask] = [0, 200, 255, 80] # Cyan with transparency + self.brush_overlay.setImage(rgba, autoLevels=False) + self.brush_overlay.setZValue(2) + else: + self.brush_overlay.clear() + + def activate_brush_deselect(self): + if self.brush_deselect_button.isChecked(): + self.brush_deselect_mode = True + + def set_x_direction_preset(self): + """Set horizontal (X) direction preset.""" + if self.image_item.image is None: + return + + # Get image dimensions + w, h = self.image_item.image.shape[:2] + + # Set horizontal line in the center of the image + center_y = h // 2 + margin = min(w // 4, 50) # Use 1/4 of width or 50 pixels, whichever is smaller + + # Create horizontal line points + start_x = margin + end_x = w - margin + + self.gradient_direction_points = [ + (start_x, center_y), + (end_x, center_y) + ] + + # Compute and set the direction vector + self.compute_direction_vector() + self.update_direction_line() + + # Ensure direction selection is off + if self.direction_button.isChecked(): + self.direction_button.setChecked(False) + self.set_gradient_direction_mode() + + # Compute candidate points + self.compute_candidate_points_gradient_direction() + + self.show_instruction("X (horizontal) direction preset applied.") + + def set_y_direction_preset(self): + """Set vertical (Y) direction preset.""" + if self.image_item.image is None: + return + + # Get image dimensions + w, h = self.image_item.image.shape[:2] + + # Set vertical line in the center of the image + center_x = w // 2 + margin = min(h // 4, 50) # Use 1/4 of height or 50 pixels, whichever is smaller + + # Create vertical line points + start_y = margin + end_y = h - margin + + self.gradient_direction_points = [ + (center_x, start_y), + (center_x, end_y) + ] + + # Compute and set the direction vector + self.compute_direction_vector() + self.update_direction_line() + + # Ensure direction selection is off + if self.direction_button.isChecked(): + self.direction_button.setChecked(False) + self.set_gradient_direction_mode() + + # Compute candidate points + self.compute_candidate_points_gradient_direction() + + self.show_instruction("Y (vertical) direction preset applied.") + +def points_along_polygon(polygon, subset_size, spacing=0): + if len(polygon) < 2: + return [] + + step = subset_size + spacing + if step <= 0: + step = 1 + + result_points = [] + + for i in range(len(polygon) - 1): + p1 = np.array(polygon[i]) + p2 = np.array(polygon[i + 1]) + segment = p2 - p1 + length = np.linalg.norm(segment) + + if length == 0: + continue + + direction = segment / length + n_points = int(length // step) + + for j in range(n_points + 1): + pt = p1 + j * step * direction + result_points.append((round(pt[0] - 0.5), round(pt[1] - 0.5))) + + return result_points + +def rois_inside_polygon(polygon, subset_size, spacing): + if len(polygon) < 3: + return [] + + polygon = np.array(polygon) + min_x, max_x = int(np.floor(np.min(polygon[:, 0]))), int(np.ceil(np.max(polygon[:, 0]))) + min_y, max_y = int(np.floor(np.min(polygon[:, 1]))), int(np.ceil(np.max(polygon[:, 1]))) + + step = subset_size + spacing + if step <= 0: + step = 1 # minimum step to avoid infinite loop + xs = np.arange(min_x, max_x+1, step) + ys = np.arange(min_y, max_y+1, step) + + grid_x, grid_y = np.meshgrid(xs, ys) + points = np.vstack([grid_x.ravel(), grid_y.ravel()]).T + + mask = Path(polygon).contains_points(points) + return [tuple(p) for p in points[mask]] + +def rois_inside_mask(mask, subset_size, spacing): + step = subset_size + spacing + if step <= 0: + step = 1 + + h, w = mask.shape + xs = np.arange(0, w, step) + ys = np.arange(0, h, step) + grid_x, grid_y = np.meshgrid(xs, ys) + + candidate_points = np.vstack([grid_y.ravel(), grid_x.ravel()]).T # (y, x) + + # Only keep points where the mask is True + selected = [tuple(p) for p in candidate_points if mask[p[0], p[1]]] + return selected + +if __name__ == "__main__": + # import pyidi + # filename = "data/data_showcase.cih" + # video = pyidi.VideoReader(filename) + # example_image = (video.get_frame(0).T)[:, ::-1] + + + import requests + from PIL import Image + import io + import numpy as np + import matplotlib.pyplot as plt + # Example black and white image (public domain) + url = "https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/camera.png" + # Fetch the image + response = requests.get(url) + img = Image.open(io.BytesIO(response.content)).convert("L") # Convert to grayscale + # Convert to numpy array + example_image = (np.array(img).T)[:, ::-1] + + + Points = SelectionGUI(example_image.astype(np.uint8)) + + print(Points.get_points()) # # print selected points for testing diff --git a/pyidi/__init__.py b/pyidi/__init__.py index 8b954ee..315c361 100644 --- a/pyidi/__init__.py +++ b/pyidi/__init__.py @@ -3,9 +3,8 @@ from .pyidi_legacy import pyIDI from . import tools from . import postprocessing -from .selection import SubsetSelection from .load_analysis import load_analysis from .video_reader import VideoReader from .methods import * -from .gui import GUI +from .GUIs import * from .fiducial import * \ No newline at end of file diff --git a/pyidi/methods/idi_method.py b/pyidi/methods/idi_method.py index 609451b..78d96e1 100644 --- a/pyidi/methods/idi_method.py +++ b/pyidi/methods/idi_method.py @@ -8,7 +8,6 @@ import inspect import matplotlib.pyplot as plt -from ..selection import SubsetSelection from ..video_reader import VideoReader from ..tools import setup_logger @@ -263,6 +262,8 @@ def _make_comparison_dict(self): return settings def set_points(self, points): + from ..GUIs.selection import SubsetSelection + if isinstance(points, list): points = np.array(points) elif isinstance(points, SubsetSelection): diff --git a/pyidi/pyidi.py b/pyidi/pyidi.py index 87eb105..f6bc5d7 100644 --- a/pyidi/pyidi.py +++ b/pyidi/pyidi.py @@ -11,7 +11,7 @@ from .methods import IDIMethod, SimplifiedOpticalFlow, LucasKanade, DirectionalLucasKanade #, LucasKanadeSc, LucasKanadeSc2, GradientBasedOpticalFlow from .video_reader import VideoReader from . import tools -from . import selection +from .GUIs import selection available_method_shortcuts = [ ('sof', SimplifiedOpticalFlow), @@ -266,7 +266,7 @@ def __repr__(self): return rep def gui(self): - from . import gui + from .GUIs import gui self.gui_obj = gui.gui(self) @property