diff --git a/.gitmodules b/.gitmodules index 8111c4f..9a53575 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "third_party/Depth-Anything-ONNX"] path = third_party/Depth-Anything-ONNX url = https://github.com/fabio-sim/Depth-Anything-ONNX.git +[submodule "third_party/sam2"] + path = third_party/sam2 + url = https://github.com/facebookresearch/sam2.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3603777..35edbe7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,6 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer - id: check-yaml - - id: check-added-large-files - id: check-merge-conflict - repo: https://github.com/psf/black diff --git a/data_engine/Readme.md b/data_engine/Readme.md new file mode 100644 index 0000000..ec28c61 --- /dev/null +++ b/data_engine/Readme.md @@ -0,0 +1,3 @@ +# Data Engine + +For now we have a POC of the project. Data engine is in progress of development. diff --git a/data_engine/poc/.dockerignore b/data_engine/poc/.dockerignore new file mode 100644 index 0000000..6aed783 --- /dev/null +++ b/data_engine/poc/.dockerignore @@ -0,0 +1,118 @@ +# Docker ignore file for SAM2 Data Engine + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Virtual environments +venv/ +env/ +ENV/ +env.bak/ +venv.bak/ + +# Cache directories +cache/ +.cache/ +*.cache + +# IDE and editor files +.vscode/ +.idea/ +*.swp +*.swo +*~ +.DS_Store +Thumbs.db + +# Logs +*.log +logs/ + +# Test files +test_output/ +test_config/ +.pytest_cache/ +.coverage +htmlcov/ + +# Model files (large) +*.pt +*.pth +*.onnx +*.trt +models/ + +# Data directories +data/ +datasets/ +videos/ +exports/ +projects/ + +# Docker files (don't copy into container) +Dockerfile* +docker-compose*.yml +.dockerignore + +# Git +.git/ +.gitignore + +# Documentation 
+docs/ +*.md +!README.md + +# Temporary files +tmp/ +temp/ +.tmp/ + +# OS specific +.DS_Store +Thumbs.db +desktop.ini + +# Jupyter +.ipynb_checkpoints/ +*.ipynb + +# Environment files +.env +.env.local +.env.*.local + +# Large media files +*.mp4 +*.avi +*.mov +*.mkv +*.wmv +*.webm +*.flv + +# Archives +*.zip +*.tar.gz +*.rar +*.7z diff --git a/data_engine/poc/Dockerfile b/data_engine/poc/Dockerfile new file mode 100644 index 0000000..b6183f5 --- /dev/null +++ b/data_engine/poc/Dockerfile @@ -0,0 +1,89 @@ +# SAM2 Data Engine Dockerfile +FROM pytorch/pytorch:2.7.1-cuda12.6-cudnn9-runtime + +# Install system dependencies for GUI applications and git +RUN apt-get update && apt-get install -y \ + libgl1-mesa-glx \ + libglib2.0-0 \ + libxrender1 \ + libxrandr2 \ + libxss1 \ + libxcursor1 \ + libxcomposite1 \ + libasound2 \ + libxi6 \ + libxtst6 \ + libqt5gui5 \ + libqt5core5a \ + libqt5widgets5 \ + qt5-gtk-platformtheme \ + libxcb1 \ + libxcb-xkb1 \ + libxkbcommon-x11-0 \ + libxcb-icccm4 \ + libxcb-image0 \ + libxcb-keysyms1 \ + libxcb-randr0 \ + libxcb-render-util0 \ + libxcb-render0 \ + libxcb-shape0 \ + libxcb-sync1 \ + libxcb-xfixes0 \ + libxcb-xinerama0 \ + libxcb-xinput0 \ + libxcb-cursor0 \ + x11-apps \ + git \ + wget \ + && rm -rf /var/lib/apt/lists/* + +# Set working directory +WORKDIR /app + +# Clone SAM2 repository +RUN git clone https://github.com/facebookresearch/segment-anything-2.git /app/third_party/sam2 + +# Copy requirements first for better caching +COPY requirements.txt . + +# Install Python dependencies including SAM2 +RUN pip install --no-cache-dir -r requirements.txt +RUN cd /app/third_party/sam2 && pip install -e . 
+
+# Create models directory and config directory
+RUN mkdir -p /app/models /app/.config /app/sam2_models
+
+# Set environment variables for Ultralytics
+ENV YOLO_CONFIG_DIR=/app/.config
+# NOTE(review): a duplicate YOLO_CONFIG_DIR definition was removed here
+
+# Download SAM2 models during build time
+RUN python3 -c "\
+import sys; \
+sys.path.append('/app/third_party/sam2'); \
+import os; \
+import urllib.request; \
+os.makedirs('/app/sam2_models', exist_ok=True); \
+checkpoint_urls = { \
+    'sam2_hiera_tiny.pt': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt', \
+    'sam2_hiera_small.pt': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt', \
+    'sam2_hiera_base_plus.pt': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt', \
+    'sam2_hiera_large.pt': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt' \
+}; \
+[urllib.request.urlretrieve(url, f'/app/sam2_models/{filename}') for filename, url in checkpoint_urls.items() if not os.path.exists(f'/app/sam2_models/{filename}')]; \
+print('All SAM2 models downloaded')"
+
+# Download YOLOv8 models during build time (keeping YOLO for dataset export)
+RUN python -c "from ultralytics import YOLO; model = YOLO('yolov8n.pt'); print('YOLOv8 nano model downloaded')"
+RUN python -c "from ultralytics import YOLO; model = YOLO('yolov8s.pt'); print('YOLOv8 small model downloaded')"
+RUN python -c "from ultralytics import YOLO; model = YOLO('yolov8m.pt'); print('YOLOv8 medium model downloaded')"
+
+# Copy the current directory contents
+COPY . .
+
+# Set environment variables for GUI
+ENV QT_X11_NO_MITSHM=1
+ENV DISPLAY=:0
+
+# Run the main GUI application
+CMD ["python", "main.py"]
diff --git a/data_engine/poc/README.md b/data_engine/poc/README.md
new file mode 100644
index 0000000..a7b16ec
--- /dev/null
+++ b/data_engine/poc/README.md
@@ -0,0 +1,30 @@
+# SAM2 Data Engine - YOLO Dataset Generator
+
+An automatic data engine GUI application that uses visual prompts to interact with SAM2 (Segment Anything Model 2) for generating YOLO training datasets with automatic video segmentation.
+
+
+
+This implementation is a POC of the actual data engine that we want to build. It should serve as a reference for the type of features we want to have.
+
+
+### How to run this code?
+
+It has a Dockerfile to run the application.
+
+First, give Docker access to the X server with this command:
+
+```bash
+xhost +
+```
+
+
+Then run the Docker container with Compose. Inside the directory, run the following command:
+
+```bash
+docker compose up
+```
+
+You should see the following application running.
+ + +![](poc_screenshot.png) diff --git a/data_engine/poc/config.py b/data_engine/poc/config.py new file mode 100644 index 0000000..b4e2940 --- /dev/null +++ b/data_engine/poc/config.py @@ -0,0 +1,330 @@ +#!/usr/bin/env python3 +""" +Configuration management for the Data Engine +""" + +import json +import os +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Optional + + +@dataclass +class SAMConfig: + """SAM2 model configuration""" + + model_name: str = "sam2_b.pt" + device: str = "auto" + multimask_output: bool = False + cache_features: bool = True + batch_size: int = 1 + + +@dataclass +class UIConfig: + """UI configuration""" + + window_width: int = 1400 + window_height: int = 900 + frame_cache_size: int = 100 + auto_save_interval: int = 300 # seconds + default_point_size: int = 5 + overlay_alpha: float = 0.5 + + +@dataclass +class ExportConfig: + """Export configuration""" + + export_format: str = "yolo" # yolo, coco, voc + image_format: str = "jpg" + compression_quality: int = 95 + include_empty_frames: bool = False + validate_annotations: bool = True + + +@dataclass +class CacheConfig: + """Cache configuration""" + + max_cache_size_gb: float = 5.0 + auto_cleanup: bool = True + cleanup_interval_hours: int = 24 + keep_recent_projects: int = 5 + + +@dataclass +class DataEngineConfig: + """Main configuration class""" + + sam: SAMConfig + ui: UIConfig + export: ExportConfig + cache: CacheConfig + project_name: str = "untitled" + last_video_path: str = "" + last_output_dir: str = "" + classes: Dict[int, str] = None + + def __post_init__(self): + if self.classes is None: + self.classes = {} + + +class ConfigManager: + """Configuration manager for the Data Engine""" + + def __init__(self, config_dir: str = None): + if config_dir is None: + config_dir = Path.home() / ".data_engine" + + self.config_dir = Path(config_dir) + self.config_dir.mkdir(exist_ok=True) + + self.config_file = self.config_dir / "config.json" + 
self.projects_dir = self.config_dir / "projects" + self.projects_dir.mkdir(exist_ok=True) + + self._config = self._load_config() + + def _load_config(self) -> DataEngineConfig: + """Load configuration from file""" + if self.config_file.exists(): + try: + with open(self.config_file, "r") as f: + data = json.load(f) + + # Convert nested dictionaries back to dataclasses + sam_config = SAMConfig(**data.get("sam", {})) + ui_config = UIConfig(**data.get("ui", {})) + export_config = ExportConfig(**data.get("export", {})) + cache_config = CacheConfig(**data.get("cache", {})) + + return DataEngineConfig( + sam=sam_config, + ui=ui_config, + export=export_config, + cache=cache_config, + project_name=data.get("project_name", "untitled"), + last_video_path=data.get("last_video_path", ""), + last_output_dir=data.get("last_output_dir", ""), + classes=data.get("classes", {}), + ) + except Exception as e: + print(f"Error loading config: {e}, using defaults") + + # Return default configuration + return DataEngineConfig( + sam=SAMConfig(), ui=UIConfig(), export=ExportConfig(), cache=CacheConfig() + ) + + def save_config(self): + """Save configuration to file""" + try: + # Convert dataclasses to dictionaries + config_dict = { + "sam": asdict(self._config.sam), + "ui": asdict(self._config.ui), + "export": asdict(self._config.export), + "cache": asdict(self._config.cache), + "project_name": self._config.project_name, + "last_video_path": self._config.last_video_path, + "last_output_dir": self._config.last_output_dir, + "classes": self._config.classes, + } + + with open(self.config_file, "w") as f: + json.dump(config_dict, f, indent=2) + + except Exception as e: + print(f"Error saving config: {e}") + + @property + def config(self) -> DataEngineConfig: + """Get current configuration""" + return self._config + + def update_sam_config(self, **kwargs): + """Update SAM configuration""" + for key, value in kwargs.items(): + if hasattr(self._config.sam, key): + setattr(self._config.sam, key, 
value) + self.save_config() + + def update_ui_config(self, **kwargs): + """Update UI configuration""" + for key, value in kwargs.items(): + if hasattr(self._config.ui, key): + setattr(self._config.ui, key, value) + self.save_config() + + def update_export_config(self, **kwargs): + """Update export configuration""" + for key, value in kwargs.items(): + if hasattr(self._config.export, key): + setattr(self._config.export, key, value) + self.save_config() + + def update_cache_config(self, **kwargs): + """Update cache configuration""" + for key, value in kwargs.items(): + if hasattr(self._config.cache, key): + setattr(self._config.cache, key, value) + self.save_config() + + def set_last_paths(self, video_path: str = None, output_dir: str = None): + """Update last used paths""" + if video_path: + self._config.last_video_path = video_path + if output_dir: + self._config.last_output_dir = output_dir + self.save_config() + + def save_project(self, project_name: str, project_data: Dict[str, Any]): + """Save project data""" + project_file = self.projects_dir / f"{project_name}.json" + try: + with open(project_file, "w") as f: + json.dump(project_data, f, indent=2) + except Exception as e: + print(f"Error saving project: {e}") + + def load_project(self, project_name: str) -> Optional[Dict[str, Any]]: + """Load project data""" + project_file = self.projects_dir / f"{project_name}.json" + if project_file.exists(): + try: + with open(project_file, "r") as f: + return json.load(f) + except Exception as e: + print(f"Error loading project: {e}") + return None + + def list_projects(self) -> list: + """List available projects""" + projects = [] + for project_file in self.projects_dir.glob("*.json"): + projects.append(project_file.stem) + return sorted(projects) + + def delete_project(self, project_name: str) -> bool: + """Delete a project""" + project_file = self.projects_dir / f"{project_name}.json" + try: + if project_file.exists(): + project_file.unlink() + return True + except 
Exception as e: + print(f"Error deleting project: {e}") + return False + + def get_cache_dir(self) -> Path: + """Get cache directory""" + cache_dir = self.config_dir / "cache" + cache_dir.mkdir(exist_ok=True) + return cache_dir + + def cleanup_cache(self, max_size_gb: float = None): + """Clean up cache directory""" + if max_size_gb is None: + max_size_gb = self._config.cache.max_cache_size_gb + + cache_dir = self.get_cache_dir() + if not cache_dir.exists(): + return + + # Calculate current cache size + total_size = 0 + cache_files = [] + + for file_path in cache_dir.rglob("*"): + if file_path.is_file(): + size = file_path.stat().st_size + total_size += size + cache_files.append((file_path, size, file_path.stat().st_mtime)) + + # Convert to GB + total_size_gb = total_size / (1024**3) + + if total_size_gb > max_size_gb: + # Sort by modification time (oldest first) + cache_files.sort(key=lambda x: x[2]) + + # Remove oldest files until under limit + target_size = max_size_gb * 0.8 # Remove to 80% of limit + current_size_gb = total_size_gb + + for file_path, size, _ in cache_files: + if current_size_gb <= target_size: + break + try: + file_path.unlink() + current_size_gb -= size / (1024**3) + except OSError: + pass + + def get_model_cache_dir(self) -> Path: + """Get model cache directory""" + model_dir = self.get_cache_dir() / "models" + model_dir.mkdir(exist_ok=True) + return model_dir + + def export_config(self, export_path: str): + """Export configuration to a file""" + try: + config_dict = { + "sam": asdict(self._config.sam), + "ui": asdict(self._config.ui), + "export": asdict(self._config.export), + "cache": asdict(self._config.cache), + "classes": self._config.classes, + } + + with open(export_path, "w") as f: + json.dump(config_dict, f, indent=2) + return True + except Exception as e: + print(f"Error exporting config: {e}") + return False + + def import_config(self, import_path: str) -> bool: + """Import configuration from a file""" + try: + with 
open(import_path, "r") as f: + config_dict = json.load(f) + + # Update configuration + if "sam" in config_dict: + self.update_sam_config(**config_dict["sam"]) + if "ui" in config_dict: + self.update_ui_config(**config_dict["ui"]) + if "export" in config_dict: + self.update_export_config(**config_dict["export"]) + if "cache" in config_dict: + self.update_cache_config(**config_dict["cache"]) + if "classes" in config_dict: + self._config.classes = config_dict["classes"] + self.save_config() + + return True + except Exception as e: + print(f"Error importing config: {e}") + return False + + +# Global configuration manager instance +_config_manager = None + + +def get_config_manager() -> ConfigManager: + """Get global configuration manager""" + global _config_manager + if _config_manager is None: + _config_manager = ConfigManager() + return _config_manager + + +def get_config() -> DataEngineConfig: + """Get current configuration""" + return get_config_manager().config diff --git a/data_engine/poc/docker-compose.yml b/data_engine/poc/docker-compose.yml new file mode 100644 index 0000000..ba15548 --- /dev/null +++ b/data_engine/poc/docker-compose.yml @@ -0,0 +1,36 @@ +services: + sam2-data-engine: + build: + context: . 
+ dockerfile: Dockerfile + image: sam2-data-engine:latest + container_name: sam2_data_engine + + # GPU support + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + + # X11 Display support + environment: + - DISPLAY=${DISPLAY} + - QT_X11_NO_MITSHM=1 + - XAUTHORITY=/tmp/.docker.xauth + - YOLO_CONFIG_DIR=/app/.config + + # Mount X11 socket and auth + volumes: + - /tmp/.X11-unix:/tmp/.X11-unix:rw + - ./data:/app/data:rw + - ./cache:/app/cache:rw + - ./projects:/app/projects:rw + - ./exports:/app/exports:rw + - ./videos:/app/videos:ro + - ../yolo_models:/app/yolo_models:rw + + privileged: True + network_mode: host diff --git a/data_engine/poc/gui/README.md b/data_engine/poc/gui/README.md new file mode 100644 index 0000000..de9fbc5 --- /dev/null +++ b/data_engine/poc/gui/README.md @@ -0,0 +1,75 @@ +# GUI Module + +This directory contains the modularized GUI components for the SAM2 Data Engine application. + +## Structure + +- `__init__.py` - Module initialization and exports +- `main_window.py` - Main application window and core logic +- `frame_viewer.py` - Interactive frame viewer widget with point selection +- `object_class_manager.py` - Widget for managing object classes +- `left_panel.py` - Left control panel with video loading, navigation, and SAM controls +- `right_panel.py` - Right panel with class management and annotations display +- `center_panel.py` - Center panel wrapper for the frame viewer +- `workers.py` - Background worker threads for video processing and SAM operations + +## Components Overview + +### DataEngineMainWindow +The main application window that orchestrates all other components. 
Handles: +- Video loading and frame caching +- SAM model initialization +- Annotation management +- YOLO dataset export + +### FrameViewer +Custom widget for displaying video frames with interactive point selection: +- Click-based point annotation (positive/negative) +- Visual feedback for selected points +- Frame scaling and coordinate mapping + +### ObjectClassManager +Widget for managing object classes in the dataset: +- Add/remove classes +- Class selection for annotation +- Dynamic class ID assignment + +### Control Panels +- **LeftControlPanel**: Video controls, navigation, SAM settings, export +- **RightPanel**: Class management, annotations list, statistics +- **CenterPanel**: Simple wrapper for the frame viewer + +### Workers +- **VideoProcessor**: Background thread for frame extraction and caching +- **SAMProcessor**: Background worker for SAM2 segmentation processing + +## Usage + +```python +from gui import DataEngineMainWindow +from PySide6.QtWidgets import QApplication +import sys + +app = QApplication(sys.argv) +window = DataEngineMainWindow() +window.show() +sys.exit(app.exec()) +``` + +## Dependencies + +- PySide6 (Qt GUI framework) +- OpenCV (cv2) for image processing +- NumPy for array operations +- PyTorch for tensor operations +- Native SAM2 for segmentation +- pathlib for path handling + +## Architecture + +The GUI follows a modular design pattern where: +1. Each major UI component is in its own file +2. The main window coordinates between components +3. Worker threads handle heavy processing +4. Signal/slot connections provide component communication +5. 
Clean separation between UI and business logic diff --git a/data_engine/poc/gui/__init__.py b/data_engine/poc/gui/__init__.py new file mode 100644 index 0000000..fb9c470 --- /dev/null +++ b/data_engine/poc/gui/__init__.py @@ -0,0 +1,23 @@ +""" +GUI module for SAM2 Data Engine +Contains modularized GUI components +""" + +from .center_panel import CenterPanel +from .frame_viewer import FrameViewer +from .left_panel import LeftControlPanel +from .main_window import DataEngineMainWindow +from .object_class_manager import ObjectClassManager +from .right_panel import RightPanel +from .workers import SAMProcessor, VideoProcessor + +__all__ = [ + "DataEngineMainWindow", + "FrameViewer", + "ObjectClassManager", + "LeftControlPanel", + "RightPanel", + "CenterPanel", + "VideoProcessor", + "SAMProcessor", +] diff --git a/data_engine/poc/gui/center_panel.py b/data_engine/poc/gui/center_panel.py new file mode 100644 index 0000000..a196ff5 --- /dev/null +++ b/data_engine/poc/gui/center_panel.py @@ -0,0 +1,25 @@ +""" +Center panel containing the frame viewer +""" + +from PySide6.QtWidgets import QVBoxLayout, QWidget + +from .frame_viewer import FrameViewer + + +class CenterPanel(QWidget): + """Center panel for frame viewing""" + + def __init__(self, parent=None): + super().__init__(parent) + self.setup_ui() + + def setup_ui(self): + """Setup the center panel UI""" + layout = QVBoxLayout() + + # Frame viewer + self.frame_viewer = FrameViewer() + + layout.addWidget(self.frame_viewer) + self.setLayout(layout) diff --git a/data_engine/poc/gui/frame_viewer.py b/data_engine/poc/gui/frame_viewer.py new file mode 100644 index 0000000..e62af62 --- /dev/null +++ b/data_engine/poc/gui/frame_viewer.py @@ -0,0 +1,236 @@ +""" +Frame viewer widget for displaying video frames with interactive point selection +""" + +import cv2 +import numpy as np +from PySide6.QtCore import Qt, Signal +from PySide6.QtGui import QImage, QPixmap +from PySide6.QtWidgets import QLabel + + +class FrameViewer(QLabel): 
+ """Custom QLabel for displaying frames with click interaction""" + + point_clicked = Signal( + int, int, int + ) # x, y, label (1 for positive, 0 for negative) + + def __init__(self): + super().__init__() + self.setMinimumSize(640, 480) + self.setStyleSheet("border: 2px solid gray;") + self.setScaledContents(False) # Don't stretch - maintain aspect ratio + self.setAlignment(Qt.AlignCenter) # Center the image + self.current_frame = None + self.scale_factor = 1.0 + self.click_mode = 1 # 1 for positive, 0 for negative + self.points = [] # List of (x, y, label) tuples + self.masks = [] # List of masks to overlay + + def set_frame(self, frame: np.ndarray): + """Set the current frame to display""" + self.current_frame = frame.copy() + self.update_display() + + def add_mask( + self, mask: np.ndarray, color: tuple = (0, 255, 0), alpha: float = 0.3 + ): + """Add a segmentation mask to display""" + self.masks.append({"mask": mask, "color": color, "alpha": alpha}) + self.update_display() + + def clear_masks(self): + """Clear all segmentation masks""" + self.masks.clear() + self.update_display() + + def set_click_mode(self, mode: int): + """Set click mode: 1 for positive points, 0 for negative points""" + self.click_mode = mode + + def add_point(self, x: int, y: int, label: int): + """Add a point to the current frame""" + self.points.append((x, y, label)) + self.update_display() + + def clear_points(self): + """Clear all points""" + self.points.clear() + self.update_display() + + def update_display(self): + """Update the display with current frame, masks, and points""" + if self.current_frame is None: + return + + display_frame = self.current_frame.copy() + + # Draw segmentation masks + for mask_info in self.masks: + mask = mask_info["mask"] + color = mask_info["color"] + alpha = mask_info["alpha"] + + # Validate mask + try: + if mask is None: + print("Warning: None mask, skipping") + continue + + # Convert to numpy array if needed + if not isinstance(mask, np.ndarray): + 
mask = np.array(mask) + + if mask.size == 0: + print("Warning: Empty mask, skipping") + continue + + # Ensure mask is 2D + if mask.ndim > 2: + mask = mask.squeeze() + if mask.ndim != 2: + print( + f"Warning: Mask has invalid dimensions {mask.shape}, skipping" + ) + continue + + # Ensure mask has valid dimensions + if mask.shape[0] == 0 or mask.shape[1] == 0: + print(f"Warning: Mask has zero dimensions {mask.shape}, skipping") + continue + + # Ensure we have valid frame dimensions + if display_frame.shape[0] <= 0 or display_frame.shape[1] <= 0: + continue + + # Resize mask to match frame if needed + if mask.shape[:2] != display_frame.shape[:2]: + # Ensure mask is uint8 for resize + if mask.dtype == np.bool_: + mask_for_resize = mask.astype(np.uint8) * 255 + elif mask.dtype in [np.float32, np.float64]: + mask_for_resize = (mask * 255).astype(np.uint8) + else: + mask_for_resize = mask.astype(np.uint8) + + # Check target dimensions are valid + target_height, target_width = display_frame.shape[:2] + if target_height <= 0 or target_width <= 0: + print( + f"Warning: Invalid target dimensions {target_height}x{target_width}, skipping" + ) + continue + + mask = cv2.resize( + mask_for_resize, + (target_width, target_height), + interpolation=cv2.INTER_NEAREST, + ) + + # Ensure final mask is binary + if mask.max() > 1: + mask_bool = mask > 127 + else: + mask_bool = mask > 0.5 + + except Exception as e: + print(f"Warning: Error processing mask: {e}") + continue + + # Create colored overlay + overlay = display_frame.copy() + try: + overlay[mask_bool] = color + except Exception as e: + print(f"Warning: Failed to apply mask overlay: {e}") + continue + + # Blend with original frame + try: + display_frame = cv2.addWeighted( + display_frame, 1 - alpha, overlay, alpha, 0 + ) + except Exception as e: + print(f"Warning: Failed to blend overlay: {e}") + continue + + # Add mask contours + try: + # Convert mask to uint8 for contour detection + if mask.max() > 1: + mask_uint8 = 
mask.astype(np.uint8) + else: + mask_uint8 = (mask_bool * 255).astype(np.uint8) + + contours, _ = cv2.findContours( + mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) + cv2.drawContours(display_frame, contours, -1, color, 2) + except Exception as e: + print(f"Warning: Failed to draw contours: {e}") + continue + + # Draw points + for x, y, label in self.points: + color = ( + (0, 255, 0) if label == 1 else (0, 0, 255) + ) # Green for positive, Red for negative + cv2.circle(display_frame, (x, y), 5, color, -1) + cv2.circle(display_frame, (x, y), 7, (255, 255, 255), 2) + + # Convert to QPixmap with proper aspect ratio + height, width, channel = display_frame.shape + bytes_per_line = 3 * width + q_image = QImage( + display_frame.data, width, height, bytes_per_line, QImage.Format_RGB888 + ).rgbSwapped() + + # Scale pixmap to fit widget while maintaining aspect ratio + pixmap = QPixmap.fromImage(q_image) + widget_size = self.size() + scaled_pixmap = pixmap.scaled( + widget_size, Qt.KeepAspectRatio, Qt.SmoothTransformation + ) + + self.setPixmap(scaled_pixmap) + + def mousePressEvent(self, event): + """Handle mouse click events""" + if self.current_frame is None or self.pixmap() is None: + return + + # Get click position relative to the actual image + widget_size = self.size() + pixmap_size = self.pixmap().size() + + # Calculate the actual displayed image position and size + # (since we use Qt.KeepAspectRatio and center alignment) + scale_x = pixmap_size.width() / self.current_frame.shape[1] + scale_y = pixmap_size.height() / self.current_frame.shape[0] + scale = min(scale_x, scale_y) + + # Calculate the displayed image size + display_width = int(self.current_frame.shape[1] * scale) + display_height = int(self.current_frame.shape[0] * scale) + + # Calculate offset (centering) + offset_x = (widget_size.width() - display_width) // 2 + offset_y = (widget_size.height() - display_height) // 2 + + # Convert widget coordinates to image coordinates + click_x = 
event.position().x() - offset_x + click_y = event.position().y() - offset_y + + # Check if click is within the image bounds + if 0 <= click_x <= display_width and 0 <= click_y <= display_height: + # Convert to original image coordinates + x = int(click_x / scale) + y = int(click_y / scale) + + # Ensure coordinates are within bounds + x = max(0, min(x, self.current_frame.shape[1] - 1)) + y = max(0, min(y, self.current_frame.shape[0] - 1)) + + self.add_point(x, y, self.click_mode) + self.point_clicked.emit(x, y, self.click_mode) diff --git a/data_engine/poc/gui/left_panel.py b/data_engine/poc/gui/left_panel.py new file mode 100644 index 0000000..bce649b --- /dev/null +++ b/data_engine/poc/gui/left_panel.py @@ -0,0 +1,207 @@ +""" +Left control panel for video loading, navigation, and SAM controls +""" + +from pathlib import Path + +from PySide6.QtCore import Qt +from PySide6.QtWidgets import ( + QCheckBox, + QGroupBox, + QHBoxLayout, + QLabel, + QListWidget, + QProgressBar, + QPushButton, + QSlider, + QSpinBox, + QVBoxLayout, + QWidget, +) + + +class LeftControlPanel(QWidget): + """Left panel containing video loading, navigation, and SAM controls""" + + def __init__(self, parent=None): + super().__init__(parent) + self.setup_ui() + + def setup_ui(self): + """Setup the left panel UI""" + layout = QVBoxLayout() + + # Video loading group + video_group = self.create_video_group() + + # Frame navigation group + nav_group = self.create_navigation_group() + + # SAM controls group + sam_group = self.create_sam_group() + + # SAM2 Video Tracking group + sam2_group = self.create_sam2_group() + + # Export group + export_group = self.create_export_group() + + # Add all groups to layout + layout.addWidget(video_group) + layout.addWidget(nav_group) + layout.addWidget(sam_group) + layout.addWidget(sam2_group) + layout.addWidget(export_group) + layout.addStretch() + + self.setLayout(layout) + + def create_video_group(self) -> QGroupBox: + """Create video loading group""" + video_group 
= QGroupBox("Video Loading") + video_layout = QVBoxLayout() + + self.load_video_btn = QPushButton("Load Video") + self.video_path_label = QLabel("No video loaded") + self.video_path_label.setWordWrap(True) + + self.output_dir_btn = QPushButton("Set Output Directory") + self.output_dir_label = QLabel("No output directory set") + self.output_dir_label.setWordWrap(True) + + video_layout.addWidget(self.load_video_btn) + video_layout.addWidget(self.video_path_label) + video_layout.addWidget(self.output_dir_btn) + video_layout.addWidget(self.output_dir_label) + video_group.setLayout(video_layout) + + return video_group + + def create_navigation_group(self) -> QGroupBox: + """Create frame navigation group""" + nav_group = QGroupBox("Frame Navigation") + nav_layout = QVBoxLayout() + + # Frame slider + self.frame_slider = QSlider(Qt.Horizontal) + + # Frame controls + frame_controls = QHBoxLayout() + self.prev_btn = QPushButton("◀ Prev") + self.frame_spinbox = QSpinBox() + self.next_btn = QPushButton("Next ▶") + + frame_controls.addWidget(self.prev_btn) + frame_controls.addWidget(self.frame_spinbox) + frame_controls.addWidget(self.next_btn) + + nav_layout.addWidget(QLabel("Frame:")) + nav_layout.addWidget(self.frame_slider) + nav_layout.addLayout(frame_controls) + nav_group.setLayout(nav_layout) + + return nav_group + + def create_sam_group(self) -> QGroupBox: + """Create SAM2 point controls group""" + sam_group = QGroupBox("SAM2 Point Controls") + sam_layout = QVBoxLayout() + + # Point selection mode + point_mode_layout = QHBoxLayout() + self.positive_point_btn = QPushButton("Positive Point") + self.positive_point_btn.setCheckable(True) + self.positive_point_btn.setChecked(True) + + self.negative_point_btn = QPushButton("Negative Point") + self.negative_point_btn.setCheckable(True) + + point_mode_layout.addWidget(self.positive_point_btn) + point_mode_layout.addWidget(self.negative_point_btn) + + # SAM action buttons + self.clear_points_btn = QPushButton("Clear Points") + 
self.segment_btn = QPushButton("Segment Object") + + sam_layout.addLayout(point_mode_layout) + sam_layout.addWidget(self.clear_points_btn) + sam_layout.addWidget(self.segment_btn) + sam_group.setLayout(sam_layout) + + return sam_group + + def create_sam2_group(self) -> QGroupBox: + """Create SAM2 Video Tracking group""" + sam2_group = QGroupBox("SAM2 Video Tracking") + sam2_layout = QVBoxLayout() + + # Video initialization + self.init_video_btn = QPushButton("Initialize Video Tracking") + self.init_video_btn.setEnabled(False) + + # Object management + object_controls = QHBoxLayout() + self.add_object_btn = QPushButton("Add Object") + self.add_object_btn.setEnabled(False) + self.remove_object_btn = QPushButton("Remove Object") + self.remove_object_btn.setEnabled(False) + + object_controls.addWidget(self.add_object_btn) + object_controls.addWidget(self.remove_object_btn) + + # Tracked objects list + self.tracked_objects_list = QListWidget() + self.tracked_objects_list.setMaximumHeight(100) + + # Propagation controls + prop_controls = QHBoxLayout() + self.propagate_forward_btn = QPushButton("Propagate Forward") + self.propagate_forward_btn.setEnabled(False) + self.propagate_backward_btn = QPushButton("Propagate Backward") + self.propagate_backward_btn.setEnabled(False) + + prop_controls.addWidget(self.propagate_forward_btn) + prop_controls.addWidget(self.propagate_backward_btn) + + # Clear tracking + self.clear_tracking_btn = QPushButton("Clear All Tracking") + self.clear_tracking_btn.setEnabled(False) + + # Auto-propagation option + self.auto_propagate_checkbox = QCheckBox("Auto-propagate on frame change") + + # Progress bar for tracking operations + self.tracking_progress = QProgressBar() + self.tracking_progress.setVisible(False) + + # Status label + self.tracking_status_label = QLabel("Video tracking not initialized") + self.tracking_status_label.setWordWrap(True) + + # Add all controls to layout + sam2_layout.addWidget(self.init_video_btn) + 
sam2_layout.addWidget(QLabel("Tracked Objects:")) + sam2_layout.addWidget(self.tracked_objects_list) + sam2_layout.addLayout(object_controls) + sam2_layout.addLayout(prop_controls) + sam2_layout.addWidget(self.clear_tracking_btn) + sam2_layout.addWidget(self.auto_propagate_checkbox) + sam2_layout.addWidget(self.tracking_progress) + sam2_layout.addWidget(self.tracking_status_label) + + sam2_group.setLayout(sam2_layout) + return sam2_group + + def create_export_group(self) -> QGroupBox: + """Create export group""" + export_group = QGroupBox("Export") + export_layout = QVBoxLayout() + + self.export_yolo_btn = QPushButton("Export YOLO Dataset") + self.progress_bar = QProgressBar() + + export_layout.addWidget(self.export_yolo_btn) + export_layout.addWidget(self.progress_bar) + export_group.setLayout(export_layout) + + return export_group diff --git a/data_engine/poc/gui/main_window.py b/data_engine/poc/gui/main_window.py new file mode 100644 index 0000000..266e484 --- /dev/null +++ b/data_engine/poc/gui/main_window.py @@ -0,0 +1,786 @@ +""" +Main window for the SAM2 Data Engine application +""" + +import json +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import cv2 +import numpy as np +import torch +from PySide6.QtCore import Qt, QThreadPool +from PySide6.QtWidgets import ( + QApplication, + QFileDialog, + QHBoxLayout, + QMainWindow, + QMessageBox, + QSplitter, + QWidget, +) + +# Import local modules +from .frame_viewer import FrameViewer +from .left_panel import LeftControlPanel +from .right_panel import RightPanel +from .sam2_video_worker import SAM2VideoWorker +from .workers import VideoProcessor + + +class DataEngineMainWindow(QMainWindow): + """Main window for the Data Engine application""" + + def __init__(self): + super().__init__() + self.setWindowTitle("SAM2 Data Engine - YOLO Dataset Generator") + self.setGeometry(100, 100, 1400, 900) + + # Initialize variables + self.video_path = None + self.output_dir = None + self.frames = [] 
+ self.current_frame_idx = 0 + self.frame_cache = {} + self.masks_cache = {} + self.annotations = {} # frame_idx: [annotations] + + # Thread pool for SAM processing + self.thread_pool = QThreadPool() + + # SAM2 Video Worker for advanced tracking + self.sam2_worker = SAM2VideoWorker() + self.video_tracking_enabled = False + self.tracked_objects = {} # obj_id -> object_info + self.next_object_id = 0 + self.propagation_results = {} # frame_idx -> {obj_id: mask} + + self.setup_ui() + self.connect_signals() + self.connect_sam2_signals() + + def setup_ui(self): + """Setup the user interface""" + central_widget = QWidget() + self.setCentralWidget(central_widget) + + # Main layout + main_layout = QHBoxLayout() + + # Create panels + self.left_panel = LeftControlPanel() + self.frame_viewer = FrameViewer() + self.right_panel = RightPanel() + + # Create splitter for resizable panels + splitter = QSplitter(Qt.Horizontal) + splitter.addWidget(self.left_panel) + splitter.addWidget(self.frame_viewer) + splitter.addWidget(self.right_panel) + splitter.setSizes([300, 700, 300]) + + main_layout.addWidget(splitter) + central_widget.setLayout(main_layout) + + def connect_signals(self): + """Connect signals between components""" + # Left panel signals + self.left_panel.load_video_btn.clicked.connect(self.load_video) + self.left_panel.output_dir_btn.clicked.connect(self.set_output_directory) + + # Frame navigation + self.left_panel.frame_slider.valueChanged.connect(self.on_frame_changed) + self.left_panel.frame_spinbox.valueChanged.connect( + self.on_frame_spinbox_changed + ) + self.left_panel.prev_btn.clicked.connect(self.prev_frame) + self.left_panel.next_btn.clicked.connect(self.next_frame) + + # SAM controls (single frame) + self.left_panel.positive_point_btn.clicked.connect( + lambda: self.set_point_mode(1) + ) + self.left_panel.negative_point_btn.clicked.connect( + lambda: self.set_point_mode(0) + ) + self.left_panel.clear_points_btn.clicked.connect(self.clear_points) + 
self.left_panel.segment_btn.clicked.connect(self.segment_current_frame) + + # SAM2 Video Tracking controls + self.left_panel.init_video_btn.clicked.connect(self.init_video_tracking) + self.left_panel.add_object_btn.clicked.connect(self.add_tracked_object) + self.left_panel.remove_object_btn.clicked.connect(self.remove_tracked_object) + self.left_panel.propagate_forward_btn.clicked.connect( + self.propagate_forward_sam2 + ) + self.left_panel.propagate_backward_btn.clicked.connect( + self.propagate_backward_sam2 + ) + self.left_panel.clear_tracking_btn.clicked.connect(self.clear_all_tracking) + self.left_panel.tracked_objects_list.itemSelectionChanged.connect( + self.on_object_selection_changed + ) + + # Export + self.left_panel.export_yolo_btn.clicked.connect(self.export_yolo_dataset) + + # Frame viewer signals + self.frame_viewer.point_clicked.connect(self.on_point_clicked) + + # Right panel signals + self.right_panel.remove_annotation_btn.clicked.connect(self.remove_annotation) + + def connect_sam2_signals(self): + """Connect SAM2 worker signals""" + self.sam2_worker.model_loaded.connect(self.on_sam2_model_loaded) + self.sam2_worker.video_initialized.connect(self.on_video_initialized) + self.sam2_worker.object_added.connect(self.on_object_added) + self.sam2_worker.propagation_progress.connect(self.on_propagation_progress) + self.sam2_worker.propagation_complete.connect(self.on_propagation_complete) + self.sam2_worker.error_occurred.connect(self.on_sam2_error) + + def load_sam_model(self): + """Load SAM2 video model""" + try: + # Load SAM2 video model for tracking + # TODO: Make this be loadable by the user. Let the user select the model it wants to use. 
+ self.sam2_worker.load_model("sam2_hiera_base_plus.pt") + print("SAM2 video model loaded successfully") + except Exception as e: + print(f"Warning: Failed to load SAM2 model: {e}") + QMessageBox.warning( + self, + "Model Loading Warning", + f"SAM2 model failed to load: {e}\nSome features may be limited.", + ) + + def load_video(self): + """Load a video file""" + file_path, _ = QFileDialog.getOpenFileName( + self, + "Select Video File", + "", + "Video Files (*.mp4 *.avi *.mov *.mkv *.wmv);;All Files (*)", + ) + + if file_path: + self.video_path = file_path + self.left_panel.video_path_label.setText(f"Video: {Path(file_path).name}") + self.load_video_frames() + + # Enable SAM2 video tracking controls + self.left_panel.init_video_btn.setEnabled(True) + self.left_panel.tracking_status_label.setText( + "Video loaded. Click 'Initialize Video Tracking' to begin." + ) + + def set_output_directory(self): + """Set the output directory for the dataset""" + # TODO: This should have a default path, like the current working directory. 
+ dir_path = QFileDialog.getExistingDirectory(self, "Select Output Directory") + + if dir_path: + self.output_dir = dir_path + self.left_panel.output_dir_label.setText(f"Output: {Path(dir_path).name}") + + def load_video_frames(self): + """Load frames from the video""" + if not self.video_path: + return + + try: + cap = cv2.VideoCapture(self.video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + # Setup frame navigation + self.left_panel.frame_slider.setMaximum(total_frames - 1) + self.left_panel.frame_spinbox.setMaximum(total_frames - 1) + + # Cache first frame for immediate display + cap.set(cv2.CAP_PROP_POS_FRAMES, 0) + ret, frame = cap.read() + if ret: + self.frame_cache[0] = frame + self.frame_viewer.set_frame(frame) + + cap.release() + + # Start background frame caching + if self.output_dir: + self.cache_frames_async() + + self.update_stats() + + except Exception as e: + QMessageBox.critical(self, "Error", f"Failed to load video: {e}") + + def cache_frames_async(self): + """Cache frames asynchronously""" + cache_dir = Path(self.output_dir) / "cache" / "frames" + cache_dir.mkdir(parents=True, exist_ok=True) + + # Process video frames in background + self.video_processor = VideoProcessor(self.video_path, str(cache_dir)) + self.video_processor.frame_processed.connect(self.on_frame_cached) + self.video_processor.progress_updated.connect( + self.left_panel.progress_bar.setValue + ) + self.video_processor.start() + + def on_frame_cached(self, frame_idx: int, frame: np.ndarray): + """Handle cached frame""" + self.frame_cache[frame_idx] = frame + + def on_frame_changed(self, frame_idx: int): + """Handle frame slider change""" + self.current_frame_idx = frame_idx + self.left_panel.frame_spinbox.setValue(frame_idx) + self.load_current_frame() + self.update_annotations_display() + + def on_frame_spinbox_changed(self, frame_idx: int): + """Handle frame spinbox change""" + self.left_panel.frame_slider.setValue(frame_idx) + + def load_current_frame(self): 
+ """Load and display the current frame""" + if self.current_frame_idx in self.frame_cache: + frame = self.frame_cache[self.current_frame_idx] + self.frame_viewer.set_frame(frame) + + # Clear existing masks and load masks for this frame + self.frame_viewer.clear_masks() + + # Show manual annotations + if self.current_frame_idx in self.annotations: + for annotation in self.annotations[self.current_frame_idx]: + mask = annotation["mask"] + class_id = annotation["class_id"] + color = self.get_class_color(class_id) + self.frame_viewer.add_mask(mask, color=color, alpha=0.4) + + # Show SAM2 tracking results + if self.current_frame_idx in self.propagation_results: + for obj_id, mask in self.propagation_results[ + self.current_frame_idx + ].items(): + if obj_id in self.tracked_objects: + class_id = self.tracked_objects[obj_id]["class_id"] + color = self.get_class_color(class_id) + self.frame_viewer.add_mask(mask, color=color, alpha=0.3) + + # Show individual tracked object masks + for obj_id, obj_info in self.tracked_objects.items(): + if self.current_frame_idx in obj_info["masks"]: + mask = obj_info["masks"][self.current_frame_idx] + class_id = obj_info["class_id"] + color = self.get_class_color(class_id) + self.frame_viewer.add_mask(mask, color=color, alpha=0.3) + else: + # Load frame from video if not cached + self.load_frame_from_video(self.current_frame_idx) + + def load_frame_from_video(self, frame_idx: int): + """Load a specific frame from video""" + if not self.video_path: + return + + try: + cap = cv2.VideoCapture(self.video_path) + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) + ret, frame = cap.read() + cap.release() + + if ret: + self.frame_cache[frame_idx] = frame + self.frame_viewer.set_frame(frame) + except Exception as e: + print(f"Error loading frame {frame_idx}: {e}") + + def export_yolo_dataset(self): + """Export annotations in YOLO format""" + if not self.output_dir: + QMessageBox.warning(self, "Error", "Please set output directory first") + return + + if 
not self.annotations and not self.tracked_objects: + QMessageBox.warning(self, "Error", "No annotations to export") + return + + try: + self.create_yolo_dataset() + QMessageBox.information( + self, "Success", "YOLO dataset exported successfully" + ) + except Exception as e: + QMessageBox.critical(self, "Export Error", f"Failed to export dataset: {e}") + + def create_yolo_dataset(self): + """Create YOLO format dataset""" + if not self.output_dir: + raise ValueError("Output directory not set") + + dataset_dir = Path(self.output_dir) / "yolo_dataset" + images_dir = dataset_dir / "images" + labels_dir = dataset_dir / "labels" + + # Create directories + images_dir.mkdir(parents=True, exist_ok=True) + labels_dir.mkdir(parents=True, exist_ok=True) + + # Export frames and create YOLO annotations + all_frames = set(self.annotations.keys()) + all_frames.update( + self.tracked_objects.get(obj_id, {}).get("masks", {}).keys() + for obj_id in self.tracked_objects + ) + + for frame_idx in all_frames: + if frame_idx in self.frame_cache: + frame = self.frame_cache[frame_idx] + + # Save image + image_filename = f"frame_{frame_idx:06d}.jpg" + image_path = images_dir / image_filename + cv2.imwrite(str(image_path), frame) + + # Create YOLO label file + label_filename = f"frame_{frame_idx:06d}.txt" + label_path = labels_dir / label_filename + + with open(label_path, "w") as f: + # Write manual annotations + if frame_idx in self.annotations: + for annotation in self.annotations[frame_idx]: + mask = annotation["mask"] + class_id = annotation["class_id"] + bbox = self.mask_to_yolo_bbox(mask, frame.shape[:2]) + f.write(f"{class_id} {' '.join(map(str, bbox))}\n") + + # Write tracked object annotations + for obj_id, obj_info in self.tracked_objects.items(): + if frame_idx in obj_info["masks"]: + mask = obj_info["masks"][frame_idx] + class_id = obj_info["class_id"] + bbox = self.mask_to_yolo_bbox(mask, frame.shape[:2]) + f.write(f"{class_id} {' '.join(map(str, bbox))}\n") + + # Create 
classes.txt file + classes_path = dataset_dir / "classes.txt" + with open(classes_path, "w") as f: + for class_id, class_name in self.right_panel.class_manager.classes.items(): + f.write(f"{class_name}\n") + + print(f"YOLO dataset exported to {dataset_dir}") + + def mask_to_yolo_bbox(self, mask: np.ndarray, image_shape: tuple) -> list: + """Convert mask to YOLO format bounding box (normalized)""" + # Find bounding box of mask + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + + if not rows.any() or not cols.any(): + return [0.5, 0.5, 0.0, 0.0] # Default small box in center + + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + + # Convert to YOLO format (center_x, center_y, width, height) - normalized + h, w = image_shape + center_x = (cmin + cmax) / 2 / w + center_y = (rmin + rmax) / 2 / h + width = (cmax - cmin) / w + height = (rmax - rmin) / h + + return [center_x, center_y, width, height] + + def prev_frame(self): + """Go to previous frame""" + if self.current_frame_idx > 0: + self.left_panel.frame_slider.setValue(self.current_frame_idx - 1) + + def next_frame(self): + """Go to next frame""" + max_frame = self.left_panel.frame_slider.maximum() + if self.current_frame_idx < max_frame: + self.left_panel.frame_slider.setValue(self.current_frame_idx + 1) + + def set_point_mode(self, mode: int): + """Set point mode (0=negative, 1=positive)""" + self.frame_viewer.point_mode = mode + + # Update button states + self.left_panel.positive_point_btn.setChecked(mode == 1) + self.left_panel.negative_point_btn.setChecked(mode == 0) + + def clear_points(self): + """Clear all points on current frame""" + self.frame_viewer.clear_points() + + def segment_current_frame(self): + """Segment current frame using SAM2""" + if not self.frame_viewer.points: + QMessageBox.warning( + self, "Error", "Please click points on the object first" + ) + return + + if self.sam2_worker.predictor is None: + QMessageBox.warning(self, "Error", "SAM2 model not 
loaded") + return + + current_class_id = self.right_panel.class_manager.get_current_class_id() + if current_class_id is None: + QMessageBox.warning(self, "Error", "Please select a class") + return + + try: + # Prepare points for SAM2 + points = np.array([[x, y] for x, y, _ in self.frame_viewer.points]) + labels = np.array([label for _, _, label in self.frame_viewer.points]) + + # Use SAM2 video worker for single frame segmentation + # Create a unique object ID for this segmentation + obj_id = len(self.annotations.get(self.current_frame_idx, [])) + 1 + + # Add object to SAM2 for tracking + self.sam2_worker.add_object(self.current_frame_idx, points, labels, obj_id) + + # Note: The actual mask will be received via the object_added signal + # and processed in the corresponding slot method + + except Exception as e: + QMessageBox.critical(self, "Error", f"Segmentation failed: {e}") + + def store_annotation(self, frame_idx: int, mask: np.ndarray, class_id: int): + """Store annotation for a frame""" + if frame_idx not in self.annotations: + self.annotations[frame_idx] = [] + + annotation = { + "mask": mask, + "class_id": class_id, + "class_name": self.right_panel.class_manager.classes.get( + class_id, "unknown" + ), + } + + self.annotations[frame_idx].append(annotation) + + def update_annotations_display(self): + """Update the annotations list for current frame""" + self.right_panel.annotations_list.clear() + + if self.current_frame_idx in self.annotations: + for i, annotation in enumerate(self.annotations[self.current_frame_idx]): + class_name = annotation["class_name"] + item_text = f"Object {i}: {class_name}" + self.right_panel.annotations_list.addItem(item_text) + + def remove_annotation(self): + """Remove selected annotation""" + current_item = self.right_panel.annotations_list.currentItem() + if current_item and self.current_frame_idx in self.annotations: + row = self.right_panel.annotations_list.row(current_item) + if 0 <= row < 
len(self.annotations[self.current_frame_idx]): + del self.annotations[self.current_frame_idx][row] + self.update_annotations_display() + self.load_current_frame() # Refresh display + + def get_class_color(self, class_id: int) -> tuple: + """Get color for a class ID""" + colors = [ + (255, 0, 0), # Red + (0, 255, 0), # Green + (0, 0, 255), # Blue + (255, 255, 0), # Yellow + (255, 0, 255), # Magenta + (0, 255, 255), # Cyan + (255, 128, 0), # Orange + (128, 0, 255), # Purple + ] + return colors[class_id % len(colors)] + + def update_stats(self): + """Update statistics display""" + total_frames = ( + self.left_panel.frame_slider.maximum() + 1 if self.video_path else 0 + ) + annotated_frames = len(self.annotations) + tracked_frames = len( + set().union( + *( + obj_info["masks"].keys() + for obj_info in self.tracked_objects.values() + ) + ) + ) + + stats_text = f"Total: {total_frames} frames | Annotated: {annotated_frames} | Tracked: {tracked_frames}" + # You can display this in a status bar or label if needed + print(stats_text) + + def init_video_tracking(self): + """Initialize video for SAM2 tracking""" + if not self.video_path: + QMessageBox.warning(self, "Error", "Please load a video first") + return + + self.left_panel.tracking_status_label.setText("Initializing video tracking...") + self.left_panel.tracking_progress.setVisible(True) + self.sam2_worker.init_video(self.video_path) + + def add_tracked_object(self): + """Add a new object to track using current points""" + if not self.video_tracking_enabled: + QMessageBox.warning( + self, + "Error", + "Video tracking not initialized. 
Click 'Initialize Video Tracking' first.", + ) + return + + if not self.frame_viewer.points: + QMessageBox.warning( + self, "Error", "Please click points on the object to track" + ) + return + + current_class_id = self.right_panel.class_manager.get_current_class_id() + if current_class_id is None: + QMessageBox.warning(self, "Error", "Please select a class") + return + + # Prepare points for SAM2 + points = [[x, y] for x, y, _ in self.frame_viewer.points] + labels = [label for _, _, label in self.frame_viewer.points] + + # Create new object ID + obj_id = self.next_object_id + self.next_object_id += 1 + + # Store object info + self.tracked_objects[obj_id] = { + "class_id": current_class_id, + "class_name": self.right_panel.class_manager.classes.get( + current_class_id, "unknown" + ), + "masks": {}, + } + + # Add object to SAM2 tracking + self.sam2_worker.add_object( + frame_idx=self.current_frame_idx, + points=points, + labels=labels, + obj_id=obj_id, + ) + + def remove_tracked_object(self): + """Remove selected object from tracking""" + if not hasattr(self.left_panel, "tracked_objects_list"): + return + + current_item = self.left_panel.tracked_objects_list.currentItem() + if current_item: + # Extract object ID from item text + item_text = current_item.text() + try: + obj_id = int(item_text.split(":")[0].split()[-1]) + + # Remove from SAM2 + self.sam2_worker.clear_object(obj_id) + + # Remove from local storage + if obj_id in self.tracked_objects: + del self.tracked_objects[obj_id] + + # Remove from propagation results + for frame_results in self.propagation_results.values(): + if obj_id in frame_results: + del frame_results[obj_id] + + # Remove from list + row = self.left_panel.tracked_objects_list.row(current_item) + self.left_panel.tracked_objects_list.takeItem(row) + + # Refresh display + self.load_current_frame() + + QMessageBox.information(self, "Success", f"Object {obj_id} removed") + + except (ValueError, IndexError): + QMessageBox.warning(self, "Error", 
"Could not parse object ID") + + def propagate_forward_sam2(self): + """Propagate tracking forward using SAM2""" + if not self.video_tracking_enabled or not self.tracked_objects: + QMessageBox.warning( + self, "Error", "No tracking initialized or objects to track" + ) + return + + self.left_panel.tracking_progress.setVisible(True) + self.sam2_worker.propagate_video( + start_frame=self.current_frame_idx, reverse=False + ) + + def propagate_backward_sam2(self): + """Propagate tracking backward using SAM2""" + if not self.video_tracking_enabled or not self.tracked_objects: + QMessageBox.warning( + self, "Error", "No tracking initialized or objects to track" + ) + return + + self.left_panel.tracking_progress.setVisible(True) + self.sam2_worker.propagate_video( + start_frame=self.current_frame_idx, reverse=True + ) + + def clear_all_tracking(self): + """Clear all SAM2 tracking""" + if self.video_tracking_enabled: + self.sam2_worker.reset_tracking() + self.tracked_objects.clear() + self.propagation_results.clear() + self.next_object_id = 0 + + # Clear tracked objects list + if hasattr(self.left_panel, "tracked_objects_list"): + self.left_panel.tracked_objects_list.clear() + + # Refresh display + self.load_current_frame() + + self.left_panel.tracking_status_label.setText("All tracking cleared") + QMessageBox.information(self, "Success", "All tracking cleared") + + def on_object_selection_changed(self): + """Handle tracked object selection change""" + # Enable/disable remove button based on selection + has_selection = self.left_panel.tracked_objects_list.currentItem() is not None + self.left_panel.remove_object_btn.setEnabled(has_selection) + + def on_point_clicked(self, x: int, y: int, label: int): + """Handle point clicked in frame viewer""" + # Add point to frame viewer + self.frame_viewer.add_point(x, y, label) + + def on_sam2_model_loaded(self, success: bool, message: str): + """Handle SAM2 model loading completion""" + if success: + 
self.left_panel.tracking_status_label.setText( + "SAM2 model loaded successfully" + ) + print("SAM2 model loaded successfully") + else: + self.left_panel.tracking_status_label.setText( + f"SAM2 model loading failed: {message}" + ) + QMessageBox.critical( + self, "Model Loading Error", f"Failed to load SAM2 model: {message}" + ) + + def on_video_initialized(self, success: bool, message: str): + """Handle video initialization completion""" + self.left_panel.tracking_progress.setVisible(False) + + if success: + self.video_tracking_enabled = True + self.left_panel.tracking_status_label.setText("Video tracking initialized") + + # Enable tracking controls + self.left_panel.add_object_btn.setEnabled(True) + self.left_panel.clear_tracking_btn.setEnabled(True) + + print("Video tracking initialized successfully") + else: + self.left_panel.tracking_status_label.setText( + f"Video initialization failed: {message}" + ) + QMessageBox.critical( + self, + "Video Initialization Error", + f"Failed to initialize video: {message}", + ) + + def on_object_added(self, obj_id: int, object_info: dict): + """Handle object addition completion""" + try: + frame_idx = object_info["frame_idx"] + mask = object_info["mask"] + + # Store the mask for this object + if obj_id in self.tracked_objects: + self.tracked_objects[obj_id]["masks"][frame_idx] = mask + + # Add to tracked objects list + class_name = self.tracked_objects[obj_id]["class_name"] + item_text = f"Object {obj_id}: {class_name}" + self.left_panel.tracked_objects_list.addItem(item_text) + + # Enable propagation controls + self.left_panel.propagate_forward_btn.setEnabled(True) + self.left_panel.propagate_backward_btn.setEnabled(True) + + # Clear points and refresh display + self.frame_viewer.clear_points() + self.load_current_frame() + + self.left_panel.tracking_status_label.setText( + f"Object {obj_id} added successfully" + ) + print(f"Object {obj_id} added successfully to frame {frame_idx}") + else: + print(f"Warning: Object {obj_id} 
is missing from tracked_objects") + + except Exception as e: + QMessageBox.warning(self, "Error", f"Failed to process added object: {e}") + print(f"Error in on_object_added: {e}") + + def on_propagation_progress(self, frame_idx: int, progress: int): + """Handle propagation progress updates""" + self.left_panel.tracking_progress.setValue(progress) + self.left_panel.tracking_status_label.setText( + f"Propagating... Frame {frame_idx} ({progress}%)" + ) + + def on_propagation_complete(self, results): + """Handle propagation completion""" + self.left_panel.tracking_progress.setVisible(False) + + # Store propagation results + self.propagation_results.update(results) + + # Also store in tracked objects + for frame_idx, frame_results in results.items(): + for obj_id, mask in frame_results.items(): + if obj_id in self.tracked_objects: + self.tracked_objects[obj_id]["masks"][frame_idx] = mask + + # Refresh current frame display + self.load_current_frame() + + num_frames = len(results) + num_objects = len( + set( + obj_id + for frame_results in results.values() + for obj_id in frame_results.keys() + ) + ) + + self.left_panel.tracking_status_label.setText( + f"Propagation complete: {num_frames} frames, {num_objects} objects" + ) + print(f"Propagation complete: {num_frames} frames processed") + + def on_sam2_error(self, error_message: str): + """Handle SAM2 worker errors""" + self.left_panel.tracking_progress.setVisible(False) + self.left_panel.tracking_status_label.setText(f"Error: {error_message}") + QMessageBox.critical( + self, "SAM2 Error", f"SAM2 operation failed: {error_message}" + ) + print(f"SAM2 Error: {error_message}") diff --git a/data_engine/poc/gui/object_class_manager.py b/data_engine/poc/gui/object_class_manager.py new file mode 100644 index 0000000..f44ce13 --- /dev/null +++ b/data_engine/poc/gui/object_class_manager.py @@ -0,0 +1,88 @@ +""" +Object class manager widget for managing object classes in the dataset +""" + +from typing import Dict, Optional + +from 
PySide6.QtWidgets import ( + QHBoxLayout, + QLabel, + QLineEdit, + QListWidget, + QPushButton, + QVBoxLayout, + QWidget, +) + + +class ObjectClassManager(QWidget): + """Widget for managing object classes""" + + def __init__(self): + super().__init__() + self.classes = {} # id: name mapping + self.setup_ui() + + def setup_ui(self): + layout = QVBoxLayout() + + # Add class controls + add_layout = QHBoxLayout() + self.class_name_input = QLineEdit() + self.class_name_input.setPlaceholderText("Enter class name") + add_button = QPushButton("Add Class") + add_button.clicked.connect(self.add_class) + + add_layout.addWidget(QLabel("Class Name:")) + add_layout.addWidget(self.class_name_input) + add_layout.addWidget(add_button) + + # Class list + self.class_list = QListWidget() + + # Remove class button + remove_button = QPushButton("Remove Selected") + remove_button.clicked.connect(self.remove_selected_class) + + layout.addLayout(add_layout) + layout.addWidget(QLabel("Classes:")) + layout.addWidget(self.class_list) + layout.addWidget(remove_button) + + self.setLayout(layout) + + def add_class(self): + """Add a new class""" + name = self.class_name_input.text().strip() + if name and name not in self.classes.values(): + class_id = len(self.classes) + self.classes[class_id] = name + self.class_list.addItem(f"{class_id}: {name}") + self.class_name_input.clear() + + def remove_selected_class(self): + """Remove selected class""" + current_item = self.class_list.currentItem() + if current_item: + # Parse class ID from item text + text = current_item.text() + class_id = int(text.split(":")[0]) + + # Remove from classes dict + if class_id in self.classes: + del self.classes[class_id] + + # Remove from list + self.class_list.takeItem(self.class_list.row(current_item)) + + def get_current_class_id(self) -> Optional[int]: + """Get currently selected class ID""" + current_item = self.class_list.currentItem() + if current_item: + text = current_item.text() + return int(text.split(":")[0]) 
+ return None + + def get_classes(self) -> Dict[int, str]: + """Get all classes""" + return self.classes.copy() diff --git a/data_engine/poc/gui/right_panel.py b/data_engine/poc/gui/right_panel.py new file mode 100644 index 0000000..d275f08 --- /dev/null +++ b/data_engine/poc/gui/right_panel.py @@ -0,0 +1,71 @@ +""" +Right panel for object class management and annotations display +""" + +from PySide6.QtWidgets import ( + QGroupBox, + QHBoxLayout, + QListWidget, + QPushButton, + QTextEdit, + QVBoxLayout, + QWidget, +) + +from .object_class_manager import ObjectClassManager + + +class RightPanel(QWidget): + """Right panel for classes and annotations""" + + def __init__(self, parent=None): + super().__init__(parent) + self.setup_ui() + + def setup_ui(self): + """Setup the right panel UI""" + layout = QVBoxLayout() + + # Object class manager + self.class_manager = ObjectClassManager() + + # Current annotations + annotations_group = self.create_annotations_group() + + # Statistics + stats_group = self.create_stats_group() + + layout.addWidget(self.class_manager) + layout.addWidget(annotations_group) + layout.addWidget(stats_group) + + self.setLayout(layout) + + def create_annotations_group(self) -> QGroupBox: + """Create annotations display group""" + annotations_group = QGroupBox("Current Frame Annotations") + annotations_layout = QVBoxLayout() + + self.annotations_list = QListWidget() + + self.remove_annotation_btn = QPushButton("Remove Selected Annotation") + + annotations_layout.addWidget(self.annotations_list) + annotations_layout.addWidget(self.remove_annotation_btn) + annotations_group.setLayout(annotations_layout) + + return annotations_group + + def create_stats_group(self) -> QGroupBox: + """Create statistics display group""" + stats_group = QGroupBox("Statistics") + stats_layout = QVBoxLayout() + + self.stats_label = QTextEdit() + self.stats_label.setMaximumHeight(150) + self.stats_label.setReadOnly(True) + + stats_layout.addWidget(self.stats_label) + 
stats_group.setLayout(stats_layout) + + return stats_group diff --git a/data_engine/poc/gui/sam2_video_worker.py b/data_engine/poc/gui/sam2_video_worker.py new file mode 100644 index 0000000..acca449 --- /dev/null +++ b/data_engine/poc/gui/sam2_video_worker.py @@ -0,0 +1,315 @@ +""" +SAM2 Video Worker for background video tracking operations +""" + +import os +import sys +from pathlib import Path + +import numpy as np +import torch +from PySide6.QtCore import QThread, Signal + +# Add SAM2 to path +sys.path.append("/app/third_party/sam2") +sys.path.append("/app/third_party/sam2/sam2") + +try: + from sam2.build_sam import build_sam2_video_predictor + from sam2.sam2_video_predictor import SAM2VideoPredictor + + SAM2_AVAILABLE = True +except ImportError as e: + print(f"SAM2 not available: {e}") + SAM2VideoPredictor = None + SAM2_AVAILABLE = False + + +class SAM2VideoWorker(QThread): + """Worker for SAM2 video tracking operations""" + + # Signals + video_initialized = Signal(bool, str) # success, message + object_added = Signal(int, dict) # obj_id, object_info + propagation_progress = Signal(int, int, str) # current_frame, total_frames, message + propagation_complete = Signal(dict) # frame_idx -> {obj_id: mask} + error_occurred = Signal(str) + model_loaded = Signal(bool, str) # success, message + + def __init__(self): + super().__init__() + self.predictor = None + self.inference_state = None + self.video_path = None + self.current_task = None + self.task_data = None + self.is_running = False + + def load_model(self, model_name="sam2_hiera_base_plus.pt"): + """Load SAM2 video predictor model""" + self.current_task = "load_model" + self.task_data = {"model_name": model_name} + self.start() + + def init_video(self, video_path): + """Initialize video for tracking""" + self.video_path = video_path + self.current_task = "init_video" + self.start() + + def add_object(self, frame_idx, points, labels, obj_id): + """Add new object for tracking""" + self.current_task = 
"add_object" + self.task_data = { + "frame_idx": frame_idx, + "points": points, + "labels": labels, + "obj_id": obj_id, + } + self.start() + + def propagate_video(self, start_frame=None, reverse=False, max_frames=None): + """Propagate tracking through video""" + self.current_task = "propagate" + self.task_data = { + "start_frame": start_frame, + "reverse": reverse, + "max_frames": max_frames, + } + self.start() + + def stop_current_task(self): + """Stop the current running task""" + self.is_running = False + + def run(self) -> None: + """Execute the current task""" + self.is_running = True + try: + if self.current_task == "load_model": + self._load_model() + elif self.current_task == "init_video": + self._init_video() + elif self.current_task == "add_object": + self._add_object() + elif self.current_task == "propagate": + self._propagate() + except Exception as e: + self.error_occurred.emit(f"Error in {self.current_task}: {str(e)}") + finally: + self.is_running = False + + def _load_model(self): + """Load SAM2 video predictor model""" + if not SAM2_AVAILABLE: + self.model_loaded.emit(False, "SAM2 not available") + return + + try: + model_name = self.task_data.get("model_name", "sam2_hiera_base_plus.pt") + + # Try different possible locations for the model + possible_paths = [ + f"/app/sam2_models/{model_name}", + f"/app/models/{model_name}", + f"/app/third_party/sam2/checkpoints/{model_name}", + model_name, # Let SAM2 handle download + ] + + model_path = None + for path in possible_paths: + if os.path.exists(path): + model_path = path + break + + if model_path is None: + model_path = model_name # Let SAM2 handle it + + # Determine config based on model name + if "tiny" in model_name: + config = "sam2_hiera_t.yaml" + elif "small" in model_name: + config = "sam2_hiera_s.yaml" + elif "base_plus" in model_name: + config = "sam2_hiera_b+.yaml" + elif "large" in model_name: + config = "sam2_hiera_l.yaml" + else: + config = "sam2_hiera_b+.yaml" # default + + # Try to find 
config file + config_path = f"/app/third_party/sam2/sam2/configs/{config}" + if not os.path.exists(config_path): + # Fallback to default config + config_path = config + + self.predictor = build_sam2_video_predictor(config_path, model_path) + self.model_loaded.emit(True, f"SAM2 model {model_name} loaded successfully") + + except Exception as e: + self.model_loaded.emit(False, f"Failed to load SAM2 model: {e}") + + def _init_video(self): + """Initialize video in SAM2""" + if not self.predictor or not self.video_path: + self.video_initialized.emit(False, "Model or video path not set") + return + + try: + self.inference_state = self.predictor.init_state( + video_path=self.video_path, + offload_video_to_cpu=True, # Save GPU memory + offload_state_to_cpu=False, + ) + + num_frames = self.inference_state.get("num_frames", 0) + self.video_initialized.emit( + True, f"Video initialized with {num_frames} frames" + ) + + except Exception as e: + self.video_initialized.emit(False, f"Failed to initialize video: {e}") + + def _add_object(self): + """Add object to tracking""" + if not self.predictor or not self.inference_state: + self.error_occurred.emit("SAM2 not initialized") + return + + data = self.task_data + try: + # Convert numpy arrays to proper format + points = np.array(data["points"], dtype=np.float32) + labels = np.array(data["labels"], dtype=np.int32) + + frame_idx, obj_ids, video_res_masks = self.predictor.add_new_points_or_box( + inference_state=self.inference_state, + frame_idx=data["frame_idx"], + obj_id=data["obj_id"], + points=points, + labels=labels, + normalize_coords=True, + ) + + # Extract mask for the new object + obj_idx = obj_ids.index(data["obj_id"]) + mask = video_res_masks[obj_idx].cpu().numpy() + + print( + f"SAM2 mask shape: {mask.shape}, dtype: {mask.dtype}, min: {mask.min()}, max: {mask.max()}" + ) + + # Ensure mask is in the correct format (2D boolean/binary) + if mask.ndim > 2: + mask = mask.squeeze() + if mask.ndim != 2: + raise ValueError(f"Mask 
has invalid dimensions: {mask.shape}") + + # Convert to binary mask if needed + if mask.dtype == np.bool_: + # Already boolean, convert to uint8 for consistency + mask = mask.astype(np.uint8) * 255 + elif mask.max() <= 1.0: + # Floating point mask, threshold and convert + mask = (mask > 0.5).astype(np.uint8) * 255 + else: + # Already in uint8 range + mask = mask.astype(np.uint8) + + object_info = { + "obj_id": data["obj_id"], + "frame_idx": frame_idx, + "mask": mask, + } + self.object_added.emit(data["obj_id"], object_info) + + except Exception as e: + self.error_occurred.emit(f"Failed to add object: {e}") + + def _propagate(self): + """Propagate tracking through video""" + if not self.predictor or not self.inference_state: + self.error_occurred.emit("SAM2 not initialized") + return + + data = self.task_data + all_results = {} + + try: + total_frames = self.inference_state["num_frames"] + current_frame = 0 + + for ( + frame_idx, + obj_ids, + video_res_masks, + ) in self.predictor.propagate_in_video( + self.inference_state, + start_frame_idx=data.get("start_frame"), + max_frame_num_to_track=data.get("max_frames"), + reverse=data.get("reverse", False), + ): + if not self.is_running: + break + + # Convert masks to numpy and organize by object + frame_results = {} + for i, obj_id in enumerate(obj_ids): + mask = video_res_masks[i].cpu().numpy() + + # Ensure mask is in the correct format (2D boolean/binary) + if mask.ndim > 2: + mask = mask.squeeze() + if mask.ndim != 2: + print( + f"Warning: Mask for obj {obj_id} has invalid dimensions: {mask.shape}" + ) + continue + + # Convert to binary mask if needed + if mask.dtype == np.bool_: + # Already boolean, convert to uint8 for consistency + mask = mask.astype(np.uint8) * 255 + elif mask.max() <= 1.0: + # Floating point mask, threshold and convert + mask = (mask > 0.5).astype(np.uint8) * 255 + else: + # Already in uint8 range + mask = mask.astype(np.uint8) + + frame_results[obj_id] = mask + + all_results[frame_idx] = 
frame_results + current_frame += 1 + + # Emit progress + progress_msg = f"Processing frame {frame_idx}/{total_frames}" + self.propagation_progress.emit( + current_frame, total_frames, progress_msg + ) + + if self.is_running: + self.propagation_complete.emit(all_results) + + except Exception as e: + self.error_occurred.emit(f"Failed to propagate: {e}") + + def clear_object(self, obj_id): + """Clear/remove an object from tracking""" + if not self.predictor or not self.inference_state: + return + + try: + self.predictor.remove_object(self.inference_state, obj_id) + except Exception as e: + self.error_occurred.emit(f"Failed to remove object {obj_id}: {e}") + + def reset_tracking(self): + """Reset all tracking""" + if not self.predictor or not self.inference_state: + return + + try: + self.predictor.reset_state(self.inference_state) + except Exception as e: + self.error_occurred.emit(f"Failed to reset tracking: {e}") diff --git a/data_engine/poc/gui/workers.py b/data_engine/poc/gui/workers.py new file mode 100644 index 0000000..90c86e1 --- /dev/null +++ b/data_engine/poc/gui/workers.py @@ -0,0 +1,72 @@ +""" +Background workers for video processing and SAM operations +""" + +from pathlib import Path + +import cv2 +import numpy as np +import torch +from PySide6.QtCore import QRunnable, QThread, Signal + + +class VideoProcessor(QThread): + """Thread for processing video frames and SAM2 operations""" + + frame_processed = Signal(int, np.ndarray) + progress_updated = Signal(int) + finished = Signal() + error_occurred = Signal(str) + + def __init__(self, video_path: str, output_dir: str): + super().__init__() + self.video_path = video_path + self.output_dir = output_dir + self.cache_dir = Path(output_dir) / "cache" + self.cache_dir.mkdir(exist_ok=True) + + def run(self): + try: + cap = cv2.VideoCapture(self.video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + frame_idx = 0 + while True: + ret, frame = cap.read() + if not ret: + break + + # Cache frame + 
frame_path = self.cache_dir / f"frame_{frame_idx:06d}.jpg" + cv2.imwrite(str(frame_path), frame) + + self.frame_processed.emit(frame_idx, frame) + self.progress_updated.emit(int((frame_idx / total_frames) * 100)) + frame_idx += 1 + + cap.release() + self.finished.emit() + + except Exception as e: + self.error_occurred.emit(str(e)) + + +class SAMProcessor(QRunnable): + """Runnable for SAM2 segmentation processing""" + + def __init__(self, model, frame, points, labels, frame_idx, callback): + super().__init__() + self.model = model + self.frame = frame + self.points = points + self.labels = labels + self.frame_idx = frame_idx + self.callback = callback + + def run(self): + try: + # Process with SAM2 + results = self.model(self.frame, points=self.points, labels=self.labels) + self.callback(self.frame_idx, results) + except Exception as e: + print(f"SAM processing error: {e}") diff --git a/data_engine/poc/main.py b/data_engine/poc/main.py new file mode 100644 index 0000000..937ea2b --- /dev/null +++ b/data_engine/poc/main.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +""" +Data Engine GUI Application +An automatic data engine that uses visual prompts to interact with AI models +for generating YOLO training datasets using SAM2 segmentation. 
+""" + +import sys + +# Import modularized GUI components +from gui import DataEngineMainWindow +from PySide6.QtWidgets import QApplication + + +def main(): + app = QApplication(sys.argv) + + window = DataEngineMainWindow() + window.show() + + sys.exit(app.exec()) + + +if __name__ == "__main__": + main() diff --git a/data_engine/poc/poc_screenshot.png b/data_engine/poc/poc_screenshot.png new file mode 100644 index 0000000..6320292 Binary files /dev/null and b/data_engine/poc/poc_screenshot.png differ diff --git a/data_engine/poc/requirements.txt b/data_engine/poc/requirements.txt new file mode 100644 index 0000000..3ed3f4f --- /dev/null +++ b/data_engine/poc/requirements.txt @@ -0,0 +1,13 @@ +PySide6>=6.5.0 +ultralytics>=8.0.0 +opencv-python>=4.8.0 +numpy>=1.21.0 +Pillow>=9.0.0 +torch>=1.13.0 +torchvision>=0.14.0 +matplotlib>=3.5.0 +tqdm>=4.64.0 +hydra-core>=1.3.0 +iopath>=0.1.9 +fvcore>=0.1.5 +eva-decord>=0.6.1 diff --git a/data_engine/poc/utils.py b/data_engine/poc/utils.py new file mode 100644 index 0000000..7cbee43 --- /dev/null +++ b/data_engine/poc/utils.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python3 +""" +Utility functions for the Data Engine application +""" + +import json +import os +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import cv2 +import numpy as np +import torch + + +def mask_to_polygon( + mask: np.ndarray, epsilon_factor: float = 0.01 +) -> List[List[float]]: + """ + Convert a binary mask to polygon coordinates in YOLO format + + Args: + mask: Binary mask as numpy array + epsilon_factor: Factor for polygon approximation (smaller = more precise) + + Returns: + List of normalized coordinates [x1, y1, x2, y2, ...] 
+ """ + # Find contours + contours, _ = cv2.findContours( + mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) + + if not contours: + return [] + + # Get the largest contour + largest_contour = max(contours, key=cv2.contourArea) + + # Approximate polygon to reduce points + epsilon = epsilon_factor * cv2.arcLength(largest_contour, True) + approx_polygon = cv2.approxPolyDP(largest_contour, epsilon, True) + + # Extract coordinates and normalize + h, w = mask.shape + normalized_coords = [] + + for point in approx_polygon: + x, y = point[0] + normalized_coords.extend([x / w, y / h]) + + return normalized_coords + + +def polygon_to_mask( + polygon: List[float], img_height: int, img_width: int +) -> np.ndarray: + """ + Convert polygon coordinates to binary mask + + Args: + polygon: List of normalized coordinates [x1, y1, x2, y2, ...] + img_height: Image height + img_width: Image width + + Returns: + Binary mask as numpy array + """ + # Convert normalized coordinates back to pixel coordinates + points = [] + for i in range(0, len(polygon), 2): + x = int(polygon[i] * img_width) + y = int(polygon[i + 1] * img_height) + points.append([x, y]) + + # Create mask + mask = np.zeros((img_height, img_width), dtype=np.uint8) + if len(points) >= 3: + cv2.fillPoly(mask, [np.array(points, dtype=np.int32)], 255) + + return mask + + +def save_project_config(config: Dict, config_path: str): + """Save project configuration to JSON file""" + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + + +def load_project_config(config_path: str) -> Dict: + """Load project configuration from JSON file""" + if os.path.exists(config_path): + with open(config_path, "r") as f: + return json.load(f) + return {} + + +def create_yolo_yaml(classes: Dict[int, str], dataset_path: str) -> str: + """ + Create YOLO dataset YAML configuration file + + Args: + classes: Dictionary mapping class IDs to class names + dataset_path: Path to the dataset directory + + Returns: + Path to 
the created YAML file + """ + yaml_content = f"""# YOLO Dataset Configuration +# Generated by SAM2 Data Engine + +path: {dataset_path} # dataset root dir +train: images # train images (relative to 'path') +val: images # val images (relative to 'path') +test: # test images (optional) + +# Classes +nc: {len(classes)} # number of classes +names: {list(classes.values())} # class names +""" + + yaml_path = Path(dataset_path) / "dataset.yaml" + with open(yaml_path, "w") as f: + f.write(yaml_content) + + return str(yaml_path) + + +def validate_yolo_annotation( + annotation_line: str, img_width: int, img_height: int +) -> bool: + """ + Validate a YOLO annotation line + + Args: + annotation_line: YOLO format annotation string + img_width: Image width + img_height: Image height + + Returns: + True if annotation is valid + """ + try: + parts = annotation_line.strip().split() + if len(parts) < 7: # class_id + at least 3 coordinate pairs + return False + + # Check class ID + class_id = int(parts[0]) + if class_id < 0: + return False + + # Check coordinates + coords = [float(x) for x in parts[1:]] + if len(coords) % 2 != 0: # Must be pairs + return False + + # Check coordinate bounds + for i in range(0, len(coords), 2): + x, y = coords[i], coords[i + 1] + if not (0 <= x <= 1 and 0 <= y <= 1): + return False + + return True + + except (ValueError, IndexError): + return False + + +def calculate_mask_area(mask: np.ndarray) -> float: + """Calculate the area of a binary mask as percentage of image""" + total_pixels = mask.shape[0] * mask.shape[1] + mask_pixels = np.sum(mask > 0) + return (mask_pixels / total_pixels) * 100 + + +def get_mask_bbox(mask: np.ndarray) -> Tuple[int, int, int, int]: + """ + Get bounding box coordinates from binary mask + + Returns: + Tuple of (x_min, y_min, x_max, y_max) + """ + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + + if not np.any(rows) or not np.any(cols): + return (0, 0, 0, 0) + + y_min, y_max = np.where(rows)[0][[0, -1]] + x_min, 
x_max = np.where(cols)[0][[0, -1]] + + return (x_min, y_min, x_max, y_max) + + +def resize_mask(mask: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray: + """ + Resize a binary mask to target size + + Args: + mask: Binary mask + target_size: (width, height) + + Returns: + Resized mask + """ + return cv2.resize( + mask.astype(np.uint8), target_size, interpolation=cv2.INTER_NEAREST + ) + + +def merge_masks(masks: List[np.ndarray]) -> np.ndarray: + """ + Merge multiple binary masks into one + + Args: + masks: List of binary masks + + Returns: + Merged mask + """ + if not masks: + return np.array([]) + + result = masks[0].copy() + for mask in masks[1:]: + result = np.logical_or(result, mask) + + return result.astype(np.uint8) + + +def apply_mask_overlay( + image: np.ndarray, + mask: np.ndarray, + color: Tuple[int, int, int] = (0, 255, 0), + alpha: float = 0.5, +) -> np.ndarray: + """ + Apply a colored mask overlay to an image + + Args: + image: RGB image + mask: Binary mask + color: RGB color for the overlay + alpha: Transparency of the overlay + + Returns: + Image with mask overlay + """ + overlay = image.copy() + overlay[mask > 0] = color + + return cv2.addWeighted(image, 1 - alpha, overlay, alpha, 0) + + +def extract_frame_info(video_path: str) -> Dict: + """ + Extract information from video file + + Args: + video_path: Path to video file + + Returns: + Dictionary with video information + """ + cap = cv2.VideoCapture(video_path) + + info = { + "total_frames": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), + "fps": cap.get(cv2.CAP_PROP_FPS), + "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), + "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), + "duration": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) / cap.get(cv2.CAP_PROP_FPS) + if cap.get(cv2.CAP_PROP_FPS) > 0 + else 0, + } + + cap.release() + return info + + +def create_directory_structure(base_path: str) -> Dict[str, str]: + """ + Create the directory structure for the data engine project + + Args: + base_path: Base 
directory path + + Returns: + Dictionary mapping folder names to paths + """ + base_path = Path(base_path) + + directories = { + "cache": base_path / "cache", + "frames": base_path / "cache" / "frames", + "masks": base_path / "cache" / "masks", + "features": base_path / "cache" / "features", + "yolo_dataset": base_path / "yolo_dataset", + "images": base_path / "yolo_dataset" / "images", + "labels": base_path / "yolo_dataset" / "labels", + "configs": base_path / "configs", + "exports": base_path / "exports", + } + + # Create all directories + for dir_path in directories.values(): + dir_path.mkdir(parents=True, exist_ok=True) + + return {name: str(path) for name, path in directories.items()} + + +def cleanup_cache(cache_dir: str, keep_recent: int = 100): + """ + Clean up old cache files, keeping only the most recent ones + + Args: + cache_dir: Path to cache directory + keep_recent: Number of recent files to keep + """ + cache_path = Path(cache_dir) + if not cache_path.exists(): + return + + # Get all files sorted by modification time + files = sorted(cache_path.glob("*"), key=lambda x: x.stat().st_mtime, reverse=True) + + # Remove old files + for file_path in files[keep_recent:]: + try: + file_path.unlink() + except OSError: + pass # Ignore errors + + +def export_statistics(annotations: Dict, output_path: str): + """ + Export annotation statistics to JSON file + + Args: + annotations: Dictionary of annotations by frame + output_path: Path to output JSON file + """ + stats = { + "total_frames": len(annotations), + "total_annotations": sum(len(anns) for anns in annotations.values()), + "annotations_per_frame": { + str(frame_idx): len(anns) for frame_idx, anns in annotations.items() + }, + "class_distribution": {}, + "average_annotations_per_frame": sum(len(anns) for anns in annotations.values()) + / len(annotations) + if annotations + else 0, + } + + # Calculate class distribution + for frame_annotations in annotations.values(): + for annotation in frame_annotations: + 
+            class_id = annotation.get("class_id", "unknown")
+            class_name = annotation.get("class_name", "unknown")
+            key = f"{class_id}_{class_name}"
+            stats["class_distribution"][key] = (
+                stats["class_distribution"].get(key, 0) + 1
+            )
+
+    with open(output_path, "w") as f:
+        json.dump(stats, f, indent=2)
diff --git a/data_engine/poc/videos/example.mp4 b/data_engine/poc/videos/example.mp4
new file mode 100644
index 0000000..283fe64
Binary files /dev/null and b/data_engine/poc/videos/example.mp4 differ
diff --git a/hocker b/hocker
index 20b74cc..474d715 120000
--- a/hocker
+++ b/hocker
@@ -1 +1 @@
-/home/catkin_ws/src/hydrus-software-stack/docker/hydrus-docker/hocker.py
\ No newline at end of file
+/home/cesar/Projects/hydrus-software-stack/docker/hydrus-docker/hocker.py
\ No newline at end of file
diff --git a/hydrus-cli b/hydrus-cli
index 3074f4a..03199d0 100755
--- a/hydrus-cli
+++ b/hydrus-cli
@@ -5,9 +5,6 @@
 # Get the directory where this wrapper is located (should be project root)
 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

-# Set HYDRUS_ROOT environment variable to ensure test system works correctly
-export HYDRUS_ROOT="$SCRIPT_DIR"
-
 # Change to the project directory to ensure proper module resolution
 cd "$SCRIPT_DIR"
diff --git a/scripts/I want to work.md b/scripts/I want to work.md
new file mode 100644
index 0000000..0ca52b8
--- /dev/null
+++ b/scripts/I want to work.md
@@ -0,0 +1,14 @@
+I want to work on an automatic data_engine that uses visual prompts to interact with the AI models. What I want is to create a GUI application built in Python with the library PySide6. This GUI should take the user inputs and feed them to an AI model to, at the end, generate a YOLO training dataset. So we will be using the model SAM2 for automatic video segmentation. We will be able to go back and forth to select the exact frame we want to work on. Cache inside a folder every individual image with their generated mask. Each mask should have assigned an object with a name. 
It should save, in the YOLO format, the segmentation, the class, and the object type classification.
+
+We will be using the model SAM2 from Ultralytics; please check this link for reference: https://docs.ultralytics.com/models/sam-2/
+
+Some of the requirements are the following:
+
+Move back and forth in the frames.
+Conduct a forward and backward propagation from the selected frame.
+Cache the frames in the file system.
+Cache the image-encoder output from SAM2 if using it; save it in a .pt file for later.
+Save all the outputs in a YOLO dataset format.
+
+
+Create this in a folder called data_engine.
diff --git a/third_party/sam2 b/third_party/sam2
new file mode 160000
index 0000000..2b90b9f
--- /dev/null
+++ b/third_party/sam2
@@ -0,0 +1 @@
+Subproject commit 2b90b9f5ceec907a1c18123530e92e794ad901a4