From 81384ce5c2fed20c71169c76561c185f5e8f4567 Mon Sep 17 00:00:00 2001 From: gapilongo Date: Sat, 25 Oct 2025 18:48:21 +0100 Subject: [PATCH] refactoring --- Backup/main.py | 833 +++++++++++++++++ {src/lg_sotf/utils => Backup}/monitoring.py | 0 src/lg_sotf/api/__init__.py | 5 + src/lg_sotf/api/app.py | 213 +++++ src/lg_sotf/api/dependencies.py | 30 + src/lg_sotf/api/models/__init__.py | 0 src/lg_sotf/api/models/alerts.py | 19 + src/lg_sotf/api/models/ingestion.py | 23 + src/lg_sotf/api/models/metrics.py | 29 + src/lg_sotf/api/models/websocket.py | 0 src/lg_sotf/api/models/workflows.py | 41 + src/lg_sotf/api/routers/__init__.py | 21 + src/lg_sotf/api/routers/alerts.py | 520 +++++++++++ src/lg_sotf/api/routers/correlations.py | 252 ++++++ src/lg_sotf/api/routers/dashboard.py | 191 ++++ src/lg_sotf/api/routers/escalations.py | 129 +++ src/lg_sotf/api/routers/ingestion.py | 233 +++++ src/lg_sotf/api/routers/metrics.py | 201 ++++ src/lg_sotf/api/routers/websocket.py | 47 + src/lg_sotf/api/utils/__init__.py | 0 src/lg_sotf/api/utils/data_processing.py | 0 src/lg_sotf/api/utils/websocket.py | 91 ++ src/lg_sotf/app_initializer.py | 689 ++++++++++++++ src/lg_sotf/core/graph.py | 374 ++++++++ src/lg_sotf/core/workflow.py | 214 ++--- src/lg_sotf/main.py | 855 +----------------- .../test_state/test_state_manager.py | 571 ++++++++++++ tests/unit/test_core/test_workflow.py | 676 ++++++++++++-- 28 files changed, 5218 insertions(+), 1039 deletions(-) create mode 100644 Backup/main.py rename {src/lg_sotf/utils => Backup}/monitoring.py (100%) create mode 100644 src/lg_sotf/api/__init__.py create mode 100644 src/lg_sotf/api/app.py create mode 100644 src/lg_sotf/api/dependencies.py create mode 100644 src/lg_sotf/api/models/__init__.py create mode 100644 src/lg_sotf/api/models/alerts.py create mode 100644 src/lg_sotf/api/models/ingestion.py create mode 100644 src/lg_sotf/api/models/metrics.py create mode 100644 src/lg_sotf/api/models/websocket.py create mode 100644 src/lg_sotf/api/models/workflows.py create mode 100644 src/lg_sotf/api/routers/__init__.py create mode 100644 src/lg_sotf/api/routers/alerts.py create mode 100644 src/lg_sotf/api/routers/correlations.py create mode 100644 src/lg_sotf/api/routers/dashboard.py create mode 100644 src/lg_sotf/api/routers/escalations.py create mode 100644 src/lg_sotf/api/routers/ingestion.py create mode 100644 src/lg_sotf/api/routers/metrics.py create mode 100644 src/lg_sotf/api/routers/websocket.py create mode 100644 src/lg_sotf/api/utils/__init__.py create mode 100644 src/lg_sotf/api/utils/data_processing.py create mode 100644 src/lg_sotf/api/utils/websocket.py create mode 100644 src/lg_sotf/app_initializer.py create mode 100644 src/lg_sotf/core/graph.py create mode 100644 tests/unit/test_core/test_state/test_state_manager.py diff --git a/Backup/main.py b/Backup/main.py new file mode 100644 index 00000000..0e451885 --- /dev/null +++ b/Backup/main.py @@ -0,0 +1,833 @@ +""" +Main application entry point for LG-SOTF - Production Version. 
+ +This module provides the main application lifecycle management including: +- Component initialization and dependency injection +- Continuous alert ingestion and processing +- Health monitoring and metrics collection +- Graceful shutdown with resource cleanup +""" + +import asyncio +import logging +import signal +import sys +from datetime import datetime, timedelta +from pathlib import Path +from typing import Optional, Set + +from lg_sotf.audit.logger import AuditLogger +from lg_sotf.audit.metrics import MetricsCollector +from lg_sotf.core.config.manager import ConfigManager +from lg_sotf.core.exceptions import LG_SOTFError +from lg_sotf.core.state.manager import StateManager +from lg_sotf.core.workflow import WorkflowEngine +from lg_sotf.storage.postgres import PostgreSQLStorage +from lg_sotf.storage.redis import RedisStorage + + +class LG_SOTFApplication: + """Main application class for LG-SOTF.""" + + def __init__(self, config_path: Optional[str] = None, setup_signal_handlers: bool = True): + """Initialize the application. + + Args: + config_path: Path to configuration file + setup_signal_handlers: Whether to setup signal handlers (disable when running under uvicorn) + """ + self.config_path = config_path + self.config_manager = None + self.state_manager = None + self.workflow_engine = None + self.audit_logger = None + self.metrics = None + self.postgres_storage = None + self.redis_storage = None + + # Application state + self.running = False + self.initialized = False + + # Task tracking for graceful shutdown + self._active_tasks: Set[asyncio.Task] = set() + self._shutdown_event = asyncio.Event() + + # Ingestion tracking + self._last_ingestion_poll: Optional[datetime] = None + self._last_health_check: Optional[datetime] = None + self._ingestion_lock = asyncio.Lock() + + # Setup signal handlers (only when not running under uvicorn) + if setup_signal_handlers: + self._setup_signal_handlers() + + def _setup_signal_handlers(self): + """Setup signal handlers for graceful shutdown.""" + def signal_handler(signum, frame): + """Handle shutdown signals.""" + logging.info(f"Received signal {signum}, initiating graceful shutdown...") + self.running = False + # Use call_soon_threadsafe to safely set the event from signal handler + try: + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(self._shutdown_event.set) + except RuntimeError: + # If no loop is running, just set it directly (shouldn't happen in normal flow) + self._shutdown_event.set() + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + async def initialize(self): + """Initialize all application components. 
+ + Raises: + LG_SOTFError: If initialization fails + """ + try: + logging.info("=== Initializing LG-SOTF Application ===") + + # Load configuration + self.config_manager = ConfigManager(self.config_path) + logging.info("✓ Configuration loaded") + + # Validate configuration + self.config_manager.validate() + logging.info("✓ Configuration validated") + + # Initialize audit and metrics + self.audit_logger = AuditLogger() + self.metrics = MetricsCollector(self.config_manager) + logging.info("✓ Audit and metrics initialized") + + # Initialize storage backends + await self._initialize_storage() + logging.info("✓ Storage backends initialized") + + # Initialize state manager + self.state_manager = StateManager(self.postgres_storage) + logging.info("✓ State manager initialized") + + # Initialize workflow engine (handles agent initialization) + # Pass Redis storage and let the workflow engine create the tool orchestrator + self.workflow_engine = WorkflowEngine( + self.config_manager, + self.state_manager, + redis_storage=self.redis_storage, + tool_orchestrator=None # Created internally by workflow engine + ) + await self.workflow_engine.initialize() + logging.info("✓ Workflow engine initialized") + + # Verify agents are registered + from lg_sotf.agents.registry import agent_registry + stats = agent_registry.get_registry_stats() + logging.info( + f"✓ Agent registry: {stats['agent_types_count']} types, " + f"{stats['agent_instances_count']} instances, " + f"{len(stats['initialized_agents'])} initialized" + ) + + # Log application start + self.audit_logger.log_application_start( + config_path=self.config_path, + version="0.1.0" + ) + + self.initialized = True + logging.info("=== LG-SOTF Application Initialized Successfully ===\n") + + except Exception as e: + logging.error(f"Failed to initialize application: {e}", exc_info=True) + raise LG_SOTFError(f"Application initialization failed: {e}") + + async def _initialize_storage(self): + """Initialize storage backends. + + Raises: + Exception: If storage initialization fails + """ + try: + # Initialize PostgreSQL + db_config = self.config_manager.get_database_config() + connection_string = ( + f"postgresql://{db_config.username}:{db_config.password}@" + f"{db_config.host}:{db_config.port}/{db_config.database}" + ) + + self.postgres_storage = PostgreSQLStorage(connection_string) + await self.postgres_storage.initialize() + logging.info(f" - PostgreSQL connected: {db_config.host}:{db_config.port}") + + # Initialize Redis + redis_config = self.config_manager.get_redis_config() + redis_password = f":{redis_config.password}@" if redis_config.password else "" + redis_connection_string = ( + f"redis://{redis_password}{redis_config.host}:{redis_config.port}/{redis_config.db}" + ) + + self.redis_storage = RedisStorage(redis_connection_string) + await self.redis_storage.initialize() + logging.info(f" - Redis connected: {redis_config.host}:{redis_config.port}") + + except Exception as e: + logging.error(f"Failed to initialize storage: {e}", exc_info=True) + raise + + async def run(self): + """Run the main application loop. 
+ + This method handles: + - Continuous alert ingestion and processing + - Periodic health checks + - Graceful shutdown on signal + """ + try: + self.running = True + logging.info("🚀 LG-SOTF Application Started") + logging.info("Press Ctrl+C to shutdown gracefully\n") + + # Create background tasks + ingestion_task = asyncio.create_task(self._ingestion_loop()) + health_check_task = asyncio.create_task(self._health_check_loop()) + + # Wait for shutdown signal + await self._shutdown_event.wait() + + # Cancel background tasks + logging.info("Stopping background tasks...") + ingestion_task.cancel() + health_check_task.cancel() + + # Wait for tasks to complete + await asyncio.gather(ingestion_task, health_check_task, return_exceptions=True) + + except asyncio.CancelledError: + logging.info("Main loop cancelled") + except Exception as e: + logging.error(f"Unexpected error in main loop: {e}", exc_info=True) + finally: + await self.shutdown() + + async def _ingestion_loop(self): + """Continuous ingestion loop. + + Polls configured sources at regular intervals and processes alerts. + """ + try: + while self.running: + try: + await self._process_alerts() + except asyncio.CancelledError: + raise + except Exception as e: + logging.error(f"Error in ingestion loop: {e}", exc_info=True) + self.metrics.increment_counter("ingestion_loop_errors") + await asyncio.sleep(5) # Back off on error + + # Sleep briefly to prevent tight loop + await asyncio.sleep(1) + + except asyncio.CancelledError: + logging.info("Ingestion loop cancelled") + + async def _health_check_loop(self): + """Periodic health check loop. + + Performs system health checks at regular intervals. + Responsive to shutdown signals. + """ + try: + while self.running: + try: + await self._perform_health_checks() + except asyncio.CancelledError: + raise + except Exception as e: + logging.error(f"Error in health check loop: {e}", exc_info=True) + self.metrics.increment_counter("health_check_errors") + + # Wait before next health check, but wake up on shutdown + try: + await asyncio.wait_for(self._shutdown_event.wait(), timeout=60) + # If we get here, shutdown was triggered + break + except asyncio.TimeoutError: + # Timeout is normal, continue to next health check + pass + + except asyncio.CancelledError: + logging.info("Health check loop cancelled") + + async def _process_alerts(self): + """Process alerts from ingestion sources. 
+ + This method: + - Respects polling interval configuration + - Enforces max concurrent alert limit + - Tracks tasks for graceful shutdown + """ + # Check if ingestion agent is available + if not self.workflow_engine or "ingestion" not in self.workflow_engine.agents: + return + + ingestion_agent = self.workflow_engine.agents["ingestion"] + + # Get polling configuration + ingestion_config = self.config_manager.get_agent_config("ingestion") + polling_interval = ingestion_config.get("polling_interval", 60) + max_concurrent = ingestion_agent.max_concurrent_alerts + + # Check if it's time to poll + if self._last_ingestion_poll is not None: + time_since_poll = (datetime.utcnow() - self._last_ingestion_poll).total_seconds() + if time_since_poll < polling_interval: + return + + # Use lock to prevent concurrent polling + if self._ingestion_lock.locked(): + return + + async with self._ingestion_lock: + try: + # Check active task count + active_count = len(self._active_tasks) + if active_count >= max_concurrent: + logging.warning( + f"Max concurrent alerts reached ({active_count}/{max_concurrent}), " + "skipping this poll cycle" + ) + self.metrics.increment_counter("ingestion_poll_skipped_max_concurrent") + return + + # Poll for new alerts + logging.debug("Polling ingestion sources...") + new_alerts = await ingestion_agent.poll_sources() + + if not new_alerts: + self._last_ingestion_poll = datetime.utcnow() + return + + logging.info(f"📥 Ingestion: Found {len(new_alerts)} new alerts") + self.metrics.increment_counter("ingestion_alerts_received", len(new_alerts)) + + # Process alerts respecting concurrency limit + processed_count = 0 + for alert in new_alerts: + # Check if we can process more alerts + if len(self._active_tasks) >= max_concurrent: + remaining = len(new_alerts) - processed_count + logging.warning( + f"Max concurrent limit reached, " + f"queueing {remaining} alerts for next cycle" + ) + self.metrics.increment_counter("ingestion_alerts_queued", remaining) + break + + try: + # Create workflow task + task = asyncio.create_task( + self._process_single_workflow(alert["id"], alert) + ) + + # Track task + self._active_tasks.add(task) + task.add_done_callback(self._active_tasks.discard) + + processed_count += 1 + + except Exception as e: + logging.error( + f"Failed to create workflow task for alert {alert.get('id', 'unknown')}: {e}", + exc_info=True + ) + self.metrics.increment_counter("workflow_creation_errors") + + logging.info( + f"✓ Created {processed_count} workflow tasks " + f"({len(self._active_tasks)} active)" + ) + self.metrics.set_gauge("active_workflow_tasks", len(self._active_tasks)) + + # Update last poll time + self._last_ingestion_poll = datetime.utcnow() + self.metrics.record_histogram("ingestion_poll_interval", polling_interval) + + except Exception as e: + logging.error(f"Ingestion polling error: {e}", exc_info=True) + self.metrics.increment_counter("ingestion_poll_errors") + + async def _process_single_workflow(self, alert_id: str, alert_data: dict): + """Process a single alert through the workflow. 
+ + Args: + alert_id: Alert identifier + alert_data: Alert data dictionary (already ingested by polling loop) + """ + start_time = datetime.utcnow() + + try: + logging.debug(f"Processing workflow for alert {alert_id}") + + # Skip ingestion node since alert is already ingested by the polling loop + result = await self.workflow_engine.execute_workflow( + alert_id, + alert_data, + skip_ingestion=True # Alert already normalized by ingestion agent polling + ) + + # Calculate processing time + processing_time = (datetime.utcnow() - start_time).total_seconds() + + logging.info( + f"✓ Alert {alert_id} processed: " + f"status={result.get('triage_status', 'unknown')}, " + f"confidence={result.get('confidence_score', 0)}, " + f"time={processing_time:.2f}s" + ) + + # Record metrics + self.metrics.increment_counter("workflow_success") + self.metrics.record_histogram("workflow_processing_time", processing_time) + + except asyncio.CancelledError: + logging.info(f"Workflow for alert {alert_id} cancelled (shutdown)") + raise + except Exception as e: + processing_time = (datetime.utcnow() - start_time).total_seconds() + logging.error( + f"✗ Failed to process alert {alert_id}: {e}", + exc_info=True + ) + self.metrics.increment_counter("workflow_errors") + self.metrics.record_histogram("workflow_error_time", processing_time) + + async def _perform_health_checks(self): + """Perform periodic health checks on all components.""" + # Check if it's time for health check (every 60 seconds) + if self._last_health_check is not None: + time_since_check = (datetime.utcnow() - self._last_health_check).total_seconds() + if time_since_check < 60: + return + + try: + logging.debug("Performing health checks...") + health_status = await self.health_check() + + if health_status: + logging.debug("✓ All components healthy") + else: + logging.warning("⚠ Some components unhealthy") + + self.metrics.set_gauge("health_check_status", 1 if health_status else 0) + self._last_health_check = datetime.utcnow() + + except Exception as e: + logging.error(f"Health check error: {e}", exc_info=True) + self.metrics.increment_counter("health_check_errors") + + async def shutdown(self): + """Shutdown the application gracefully. 
+ + This method: + - Cancels all active workflow tasks + - Shuts down agents + - Closes storage connections + - Cleans up resources + """ + try: + logging.info("\n=== Shutting Down LG-SOTF Application ===") + + # Cancel active workflow tasks + if self._active_tasks: + task_count = len(self._active_tasks) + logging.info(f"Cancelling {task_count} active workflow tasks...") + + for task in self._active_tasks: + if not task.done(): + task.cancel() + + # Wait for tasks to complete with timeout + try: + await asyncio.wait_for( + asyncio.gather(*self._active_tasks, return_exceptions=True), + timeout=10.0 + ) + logging.info(f"✓ All {task_count} workflow tasks cancelled") + except asyncio.TimeoutError: + logging.warning(f"⚠ Some workflow tasks did not complete within timeout") + + # Log application shutdown + if self.audit_logger: + self.audit_logger.log_application_shutdown() + + # Shutdown agents + await self._shutdown_agents() + + # Close storage connections + await self._shutdown_storage() + + # Shutdown metrics collection + if self.metrics: + try: + self.metrics.shutdown() + logging.info("✓ Metrics collection stopped") + except Exception as e: + logging.warning(f"⚠ Error shutting down metrics: {e}") + + logging.info("=== LG-SOTF Application Shutdown Complete ===\n") + + except Exception as e: + logging.error(f"Error during shutdown: {e}", exc_info=True) + + async def _shutdown_agents(self): + """Shutdown all registered agents.""" + try: + from lg_sotf.agents.registry import agent_registry + + logging.info("Shutting down agents...") + await agent_registry.cleanup_all_agents() + logging.info("✓ All agents stopped") + + except Exception as e: + logging.warning(f"⚠ Error shutting down agents: {e}") + + async def _shutdown_storage(self): + """Shutdown storage connections.""" + storage_tasks = [] + + # Schedule PostgreSQL cleanup + if self.postgres_storage: + try: + storage_tasks.append( + asyncio.create_task(self.postgres_storage.close()) + ) + except Exception as e: + logging.warning(f"⚠ Error scheduling PostgreSQL close: {e}") + + # Schedule Redis cleanup + if self.redis_storage: + try: + storage_tasks.append( + asyncio.create_task(self.redis_storage.close()) + ) + except Exception as e: + logging.warning(f"⚠ Error scheduling Redis close: {e}") + + # Wait for storage cleanup with timeout + if storage_tasks: + try: + await asyncio.wait_for( + asyncio.gather(*storage_tasks, return_exceptions=True), + timeout=5.0 + ) + logging.info("✓ Storage connections closed") + except asyncio.TimeoutError: + logging.warning("⚠ Storage cleanup timed out") + except Exception as e: + logging.warning(f"⚠ Error during storage cleanup: {e}") + + async def health_check(self) -> bool: + """Perform comprehensive health check. 
+ + Returns: + bool: True if all components healthy, False otherwise + """ + try: + health_results = { + 'config_manager': False, + 'state_manager': False, + 'workflow_engine': False, + 'postgres_storage': False, + 'redis_storage': False, + 'agents': False + } + + # Check configuration + if self.config_manager: + health_results['config_manager'] = True + + # Check state manager + if self.state_manager: + health_results['state_manager'] = True + + # Check workflow engine + if self.workflow_engine: + health_results['workflow_engine'] = True + + # Check PostgreSQL + if self.postgres_storage: + health_results['postgres_storage'] = await self.postgres_storage.health_check() + + # Check Redis + if self.redis_storage: + health_results['redis_storage'] = await self.redis_storage.health_check() + + # Check agents + try: + from lg_sotf.agents.registry import agent_registry + + # Check if any agent is healthy + if agent_registry.agent_exists("ingestion_instance"): + ingestion_agent = agent_registry.get_agent("ingestion_instance") + if hasattr(ingestion_agent, 'health_check'): + health_results['agents'] = await ingestion_agent.health_check() + else: + health_results['agents'] = ingestion_agent.initialized + + except Exception as e: + logging.debug(f"Agent health check error: {e}") + health_results['agents'] = False + + # Record component health metrics + if self.metrics: + for component, status in health_results.items(): + self.metrics.set_gauge(f"health_{component}", 1 if status else 0) + + # Calculate overall health + overall_health = all(health_results.values()) + + # Log unhealthy components + unhealthy = [comp for comp, status in health_results.items() if not status] + if unhealthy: + logging.debug(f"Unhealthy components: {', '.join(unhealthy)}") + + return overall_health + + except Exception as e: + logging.error(f"Health check failed: {e}", exc_info=True) + return False + + async def process_single_alert(self, alert_id: str, alert_data: dict) -> dict: + """Process a single alert through the workflow. + + This method is used for testing and manual alert processing. + + Args: + alert_id: Alert identifier + alert_data: Alert data dictionary + + Returns: + dict: Workflow result + + Raises: + LG_SOTFError: If workflow engine not initialized + """ + try: + if not self.workflow_engine: + raise LG_SOTFError("Workflow engine not initialized") + + logging.info(f"Processing single alert: {alert_id}") + + result = await self.workflow_engine.execute_workflow(alert_id, alert_data) + + logging.info(f"Alert {alert_id} processed successfully") + return result + + except Exception as e: + logging.error(f"Failed to process alert {alert_id}: {e}", exc_info=True) + raise + + def get_application_status(self) -> dict: + """Get comprehensive application status. 
+ + Returns: + dict: Application status information + """ + try: + status = { + 'running': self.running, + 'initialized': self.initialized, + 'active_workflow_tasks': len(self._active_tasks), + 'last_ingestion_poll': self._last_ingestion_poll.isoformat() if self._last_ingestion_poll else None, + 'last_health_check': self._last_health_check.isoformat() if self._last_health_check else None, + 'components': { + 'config_manager': self.config_manager is not None, + 'state_manager': self.state_manager is not None, + 'workflow_engine': self.workflow_engine is not None, + 'audit_logger': self.audit_logger is not None, + 'metrics': self.metrics is not None + }, + 'storage': { + 'postgres': self.postgres_storage is not None, + 'redis': self.redis_storage is not None + } + } + + # Add agent status + try: + from lg_sotf.agents.registry import agent_registry + status['agents'] = agent_registry.get_registry_stats() + except Exception as e: + logging.warning(f"Error getting agent status: {e}") + status['agents'] = {'error': str(e)} + + return status + + except Exception as e: + logging.error(f"Error getting application status: {e}", exc_info=True) + return {'error': str(e)} + + +async def main(): + """Main entry point for LG-SOTF application.""" + import argparse + import json + + # Parse command line arguments + parser = argparse.ArgumentParser( + description="LG-SOTF: LangGraph SOC Triage & Orchestration Framework", + formatter_class=argparse.RawDescriptionHelpFormatter + ) + + parser.add_argument( + "--config", "-c", + type=str, + help="Path to configuration file" + ) + + parser.add_argument( + "--mode", "-m", + choices=["run", "health-check", "process-alert"], + default="run", + help="Application mode (default: run)" + ) + + parser.add_argument( + "--alert-id", + type=str, + help="Alert ID for process-alert mode" + ) + + parser.add_argument( + "--alert-data", + type=str, + help="Alert data JSON file for process-alert mode" + ) + + parser.add_argument( + "--log-level", "-l", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + default="INFO", + help="Logging level (default: INFO)" + ) + + parser.add_argument( + "--version", "-v", + action="version", + version="LG-SOTF 0.1.0" + ) + + args = parser.parse_args() + + # Setup logging + logging.basicConfig( + level=getattr(logging, args.log_level), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S" + ) + + # Create application instance + app = LG_SOTFApplication(config_path=args.config) + + try: + # Initialize application + await app.initialize() + + # Execute based on mode + if args.mode == "health-check": + # Perform health check + health_status = await app.health_check() + status_info = app.get_application_status() + + print("\n" + "=" * 60) + print("🏥 LG-SOTF APPLICATION HEALTH CHECK") + print("=" * 60) + print(f"\nOverall Health: {'✅ HEALTHY' if health_status else '❌ UNHEALTHY'}") + print(f"Running: {'✅ Yes' if status_info['running'] else '❌ No'}") + print(f"Initialized: {'✅ Yes' if status_info['initialized'] else '❌ No'}") + print(f"Active Workflow Tasks: {status_info['active_workflow_tasks']}") + + print("\n" + "-" * 60) + print("🔧 COMPONENTS") + print("-" * 60) + for component, status in status_info['components'].items(): + status_icon = '✅' if status else '❌' + print(f" {status_icon} {component}") + + print("\n" + "-" * 60) + print("🤖 AGENTS") + print("-" * 60) + if 'error' in status_info['agents']: + print(f" ❌ Error: {status_info['agents']['error']}") + else: + agents_info = 
status_info['agents']
+                print(f"  Types: {agents_info['agent_types_count']}")
+                print(f"  Instances: {agents_info['agent_instances_count']}")
+                print(f"  Initialized: {len(agents_info['initialized_agents'])}")
+                if agents_info['initialized_agents']:
+                    for agent_name in agents_info['initialized_agents']:
+                        print(f"    ✅ {agent_name}")
+
+            print("\n" + "-" * 60)
+            print("💾 STORAGE")
+            print("-" * 60)
+            for storage_type, status in status_info['storage'].items():
+                status_icon = '✅' if status else '❌'
+                print(f"  {status_icon} {storage_type}")
+
+            print("\n" + "=" * 60 + "\n")
+
+            sys.exit(0 if health_status else 1)
+
+        elif args.mode == "process-alert":
+            # Process a single alert
+            if not args.alert_id:
+                print("❌ Error: Alert ID is required for process-alert mode")
+                print("Usage: python -m lg_sotf.main --mode process-alert --alert-id <id> [--alert-data <file>]")
+                sys.exit(1)
+
+            # Load alert data
+            if args.alert_data:
+                with open(args.alert_data, 'r') as f:
+                    alert_data = json.load(f)
+            else:
+                # Use sample alert data
+                alert_data = {
+                    "id": args.alert_id,
+                    "source": "manual",
+                    "timestamp": datetime.utcnow().isoformat(),
+                    "severity": "high",
+                    "title": "Manual test alert",
+                    "description": "Test alert for manual processing"
+                }
+
+            print(f"\n🔄 Processing alert: {args.alert_id}")
+            result = await app.process_single_alert(args.alert_id, alert_data)
+
+            print("\n" + "=" * 60)
+            print("✅ ALERT PROCESSED SUCCESSFULLY")
+            print("=" * 60)
+            print(f"Alert ID: {args.alert_id}")
+            print(f"Final Status: {result.get('triage_status', 'unknown')}")
+            print(f"Confidence Score: {result.get('confidence_score', 0)}")
+            print(f"Processing Notes: {len(result.get('processing_notes', []))}")
+            print("=" * 60 + "\n")
+
+            sys.exit(0)
+
+        else:
+            # Run application in continuous mode
+            await app.run()
+
+    except KeyboardInterrupt:
+        logging.info("\nApplication interrupted by user")
+        sys.exit(0)
+    except Exception as e:
+        logging.error(f"Application failed: {e}", exc_info=True)
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
\ No newline at end of file
diff --git a/src/lg_sotf/utils/monitoring.py b/Backup/monitoring.py
similarity index 100%
rename from src/lg_sotf/utils/monitoring.py
rename to Backup/monitoring.py
diff --git a/src/lg_sotf/api/__init__.py b/src/lg_sotf/api/__init__.py
new file mode 100644
index 00000000..b46e6013
--- /dev/null
+++ b/src/lg_sotf/api/__init__.py
@@ -0,0 +1,5 @@
+"""LG-SOTF API package."""
+
+from .app import app, create_app
+
+__all__ = ["app", "create_app"]
diff --git a/src/lg_sotf/api/app.py b/src/lg_sotf/api/app.py
new file mode 100644
index 00000000..8a808b97
--- /dev/null
+++ b/src/lg_sotf/api/app.py
@@ -0,0 +1,213 @@
+"""FastAPI application factory for LG-SOTF API."""
+
+import asyncio
+import logging
+from typing import Optional
+
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
+from lg_sotf.app_initializer import LG_SOTFApplication
+from lg_sotf.api.utils.websocket import WebSocketManager
+from lg_sotf.api.routers import (
+    alerts,
+    correlations,
+    dashboard,
+    escalations,
+    ingestion,
+    metrics,
+    websocket,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def create_app(
+    config_path: str = "configs/development.yaml",
+    setup_signal_handlers: bool = False,
+) -> FastAPI:
+    """Create and configure the FastAPI application.
+ + Args: + config_path: Path to configuration file + setup_signal_handlers: Whether to setup signal handlers (False for uvicorn) + + Returns: + Configured FastAPI application instance + """ + # Configure logging + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + + # Create FastAPI app + app = FastAPI( + title="LG-SOTF Dashboard API", + description="Production-grade SOC Dashboard API", + version="1.0.0", + docs_url="/api/docs", + redoc_url="/api/redoc" + ) + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Create application instances (will be initialized in startup event) + lg_sotf_app = LG_SOTFApplication( + config_path=config_path, + setup_signal_handlers=setup_signal_handlers + ) + ws_manager = WebSocketManager() + + # Store in app state for dependency injection + app.state.lg_sotf_app = lg_sotf_app + app.state.ws_manager = ws_manager + + # Track background tasks for proper shutdown + app.state.background_tasks = [] + + # Register routers + app.include_router(metrics.router) + app.include_router(alerts.router) + app.include_router(ingestion.router) + app.include_router(dashboard.router) + app.include_router(correlations.router) + app.include_router(escalations.router) + app.include_router(websocket.router) + + # Startup event handler + @app.on_event("startup") + async def startup(): + """Initialize application on startup.""" + logger.info("Starting LG-SOTF API server...") + + # Initialize LG-SOTF application + await lg_sotf_app.initialize() + logger.info("LG-SOTF application initialized") + + # Start background tasks + _start_background_tasks(app) + logger.info("Background tasks started") + + logger.info("✅ LG-SOTF API server ready") + + # Shutdown event handler + @app.on_event("shutdown") + async def shutdown(): + """Cleanup on shutdown.""" + logger.info("🛑 Shutting down API server...") + + # Close all WebSocket connections + if ws_manager.active_connections: + logger.info(f"Closing {len(ws_manager.active_connections)} WebSocket connections...") + connections = list(ws_manager.active_connections.values()) + for ws in connections: + try: + await ws.close() + except Exception as e: + logger.error(f"Error closing WebSocket: {e}") + ws_manager.active_connections.clear() + logger.info("✓ WebSocket connections closed") + + # Cancel background tasks + if app.state.background_tasks: + logger.info(f"Cancelling {len(app.state.background_tasks)} background tasks...") + for task in app.state.background_tasks: + if not task.done(): + task.cancel() + + # Wait for tasks to complete with timeout + try: + await asyncio.wait_for( + asyncio.gather(*app.state.background_tasks, return_exceptions=True), + timeout=5.0 + ) + logger.info("✓ Background tasks cancelled") + except asyncio.TimeoutError: + logger.warning("⚠ Some background tasks did not complete within timeout") + + # Shutdown LG-SOTF application + await lg_sotf_app.shutdown() + logger.info("✅ Shutdown complete") + + return app + + +def _start_background_tasks(app: FastAPI): + """Start background monitoring and update tasks.""" + ws_manager: WebSocketManager = app.state.ws_manager + lg_sotf_app: LG_SOTFApplication = app.state.lg_sotf_app + + async def metrics_updater(): + """Periodically collect and broadcast system metrics.""" + while True: + try: + await asyncio.sleep(10) + + # Import here to avoid circular dependency + from lg_sotf.api.routers.metrics import 
_collect_system_metrics + + metrics = await _collect_system_metrics(lg_sotf_app) + + await ws_manager.broadcast({ + "type": "system_metrics", + "data": metrics.model_dump() + }, "system_metrics") + + except asyncio.CancelledError: + logger.info("Metrics updater cancelled") + break + except Exception as e: + logger.error(f"Metrics updater error: {e}") + + async def ingestion_monitor(): + """Monitor ingestion activity and broadcast updates.""" + while True: + try: + await asyncio.sleep(5) # Check every 5 seconds + + if not lg_sotf_app.workflow_engine: + continue + + ingestion_agent = ( + lg_sotf_app.workflow_engine.agents.get("ingestion_instance") or + lg_sotf_app.workflow_engine.agents.get("ingestion") + ) + + if not ingestion_agent: + continue + + # Broadcast ingestion stats + await ws_manager.broadcast({ + "type": "ingestion_stats", + "data": { + "total_ingested": ingestion_agent.ingestion_stats["total_ingested"], + "total_deduplicated": ingestion_agent.ingestion_stats["total_deduplicated"], + "total_errors": ingestion_agent.ingestion_stats["total_errors"], + "by_source": dict(ingestion_agent.ingestion_stats["by_source"]), + "enabled_sources": ingestion_agent.enabled_sources, + "last_poll": lg_sotf_app._last_ingestion_poll.isoformat() if lg_sotf_app._last_ingestion_poll else None + } + }, "ingestion_updates") + + except asyncio.CancelledError: + logger.info("Ingestion monitor cancelled") + break + except Exception as e: + logger.error(f"Ingestion monitor error: {e}") + + # Create and track background tasks + app.state.background_tasks.append(asyncio.create_task(ws_manager.heartbeat_loop())) + app.state.background_tasks.append(asyncio.create_task(metrics_updater())) + app.state.background_tasks.append(asyncio.create_task(ingestion_monitor())) + + +# Create app instance for uvicorn +app = create_app() diff --git a/src/lg_sotf/api/dependencies.py b/src/lg_sotf/api/dependencies.py new file mode 100644 index 00000000..99c49b89 --- /dev/null +++ b/src/lg_sotf/api/dependencies.py @@ -0,0 +1,30 @@ +"""Dependency injection for FastAPI routes.""" + +from fastapi import Depends, Request + +from lg_sotf.app_initializer import LG_SOTFApplication +from lg_sotf.api.utils.websocket import WebSocketManager + + +def get_lg_sotf_app(request: Request) -> LG_SOTFApplication: + """Get the LG-SOTF application instance. + + Args: + request: FastAPI request object + + Returns: + LG-SOTF application instance from app state + """ + return request.app.state.lg_sotf_app + + +def get_websocket_manager(request: Request) -> WebSocketManager: + """Get the WebSocket manager instance. 
+ + Args: + request: FastAPI request object + + Returns: + WebSocket manager from app state + """ + return request.app.state.ws_manager diff --git a/src/lg_sotf/api/models/__init__.py b/src/lg_sotf/api/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/lg_sotf/api/models/alerts.py b/src/lg_sotf/api/models/alerts.py new file mode 100644 index 00000000..f87f959c --- /dev/null +++ b/src/lg_sotf/api/models/alerts.py @@ -0,0 +1,19 @@ +"""Alert-related Pydantic models.""" + +from typing import Any, Dict, Optional +from pydantic import BaseModel + + +class AlertRequest(BaseModel): + """Request model for processing a new alert.""" + alert_data: Dict[str, Any] + priority: Optional[str] = "normal" + + +class AlertResponse(BaseModel): + """Response model for alert processing.""" + alert_id: str + status: str + workflow_instance_id: str + processing_started: bool + estimated_completion: Optional[str] = None diff --git a/src/lg_sotf/api/models/ingestion.py b/src/lg_sotf/api/models/ingestion.py new file mode 100644 index 00000000..c6edb8c8 --- /dev/null +++ b/src/lg_sotf/api/models/ingestion.py @@ -0,0 +1,23 @@ +"""Ingestion-related Pydantic models.""" +from typing import Any, Dict, List, Optional +from pydantic import BaseModel + +class IngestionStatusResponse(BaseModel): + is_active: bool + last_poll_time: Optional[str] + next_poll_time: Optional[str] + polling_interval: int + sources_enabled: List[str] + sources_stats: Dict[str, Dict[str, int]] + total_ingested: int + total_deduplicated: int + total_errors: int + +class IngestionControlRequest(BaseModel): + action: str + sources: Optional[List[str]] = None + +class SourceConfigRequest(BaseModel): + source_name: str + enabled: bool + config: Optional[Dict[str, Any]] = None diff --git a/src/lg_sotf/api/models/metrics.py b/src/lg_sotf/api/models/metrics.py new file mode 100644 index 00000000..567ffe19 --- /dev/null +++ b/src/lg_sotf/api/models/metrics.py @@ -0,0 +1,29 @@ +"""Metrics and health-related Pydantic models.""" +from typing import Any, Dict, List, Optional +from pydantic import BaseModel + +class MetricsResponse(BaseModel): + timestamp: str + alerts_processed_today: int + alerts_in_progress: int + average_processing_time: float + success_rate: float + agent_health: Dict[str, bool] + system_health: bool + +class DashboardStatsResponse(BaseModel): + total_alerts_today: int + high_priority_alerts: int + alerts_by_status: Dict[str, int] + alerts_by_severity: Dict[str, int] + top_threat_indicators: List[Dict[str, Any]] + recent_escalations: List[Dict[str, Any]] + processing_time_avg: float + +class AgentStatusResponse(BaseModel): + agent_name: str + status: str + last_execution: Optional[str] + success_rate: float + average_execution_time: float + error_count: int diff --git a/src/lg_sotf/api/models/websocket.py b/src/lg_sotf/api/models/websocket.py new file mode 100644 index 00000000..e69de29b diff --git a/src/lg_sotf/api/models/workflows.py b/src/lg_sotf/api/models/workflows.py new file mode 100644 index 00000000..1ba272fe --- /dev/null +++ b/src/lg_sotf/api/models/workflows.py @@ -0,0 +1,41 @@ +"""Workflow-related Pydantic models.""" +from typing import Any, Dict, List, Optional +from pydantic import BaseModel + +class WorkflowStatusResponse(BaseModel): + alert_id: str + workflow_instance_id: str + current_node: str + triage_status: str + confidence_score: int + threat_score: Optional[int] = 0 + processing_notes: List[str] + enriched_data: Optional[Dict[str, Any]] = {} + escalation_info: Optional[Dict[str, Any]] = 
None + response_execution: Optional[Dict[str, Any]] = None + fp_indicators: Optional[List[str]] = [] + tp_indicators: Optional[List[str]] = [] + correlations: Optional[List[Dict[str, Any]]] = [] + correlation_score: Optional[int] = 0 + analysis_conclusion: Optional[str] = None + recommended_actions: Optional[List[str]] = [] + last_updated: str + progress_percentage: int + +class CorrelationResponse(BaseModel): + alert_id: str + correlations: List[Dict[str, Any]] + correlation_score: int + attack_campaign_indicators: List[str] + threat_actor_patterns: List[str] + +class FeedbackRequest(BaseModel): + analyst_username: str + decision: str + confidence: int + notes: str + actions_taken: Optional[List[str]] = None + actions_recommended: Optional[List[str]] = None + triage_correct: Optional[bool] = None + correlation_helpful: Optional[bool] = None + analysis_accurate: Optional[bool] = None diff --git a/src/lg_sotf/api/routers/__init__.py b/src/lg_sotf/api/routers/__init__.py new file mode 100644 index 00000000..22b044c0 --- /dev/null +++ b/src/lg_sotf/api/routers/__init__.py @@ -0,0 +1,21 @@ +"""API routers for LG-SOTF Dashboard.""" + +from . import ( + alerts, + correlations, + dashboard, + escalations, + ingestion, + metrics, + websocket, +) + +__all__ = [ + "alerts", + "correlations", + "dashboard", + "escalations", + "ingestion", + "metrics", + "websocket", +] diff --git a/src/lg_sotf/api/routers/alerts.py b/src/lg_sotf/api/routers/alerts.py new file mode 100644 index 00000000..aa2a90e9 --- /dev/null +++ b/src/lg_sotf/api/routers/alerts.py @@ -0,0 +1,520 @@ +"""Alert processing and status endpoints.""" + +import asyncio +import json +import logging +import time +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional +from uuid import uuid4 + +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException + +from lg_sotf.app_initializer import LG_SOTFApplication +from lg_sotf.api.dependencies import get_lg_sotf_app, get_websocket_manager +from lg_sotf.api.models.alerts import AlertRequest, AlertResponse +from lg_sotf.api.models.workflows import ( + WorkflowStatusResponse, + CorrelationResponse, + FeedbackRequest, +) +from lg_sotf.api.utils.websocket import WebSocketManager + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/alerts", tags=["alerts"]) + + +@router.post("/process", response_model=AlertResponse) +async def process_alert( + alert_request: AlertRequest, + background_tasks: BackgroundTasks, + lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app), + ws_manager: WebSocketManager = Depends(get_websocket_manager), +): + """Submit a new alert for processing through the workflow.""" + try: + alert_id = str(uuid4()) + + background_tasks.add_task( + _process_alert_background, + alert_id, + alert_request.alert_data, + lg_sotf_app, + ws_manager + ) + + await ws_manager.broadcast({ + "type": "new_alert", + "alert_id": alert_id, + "severity": alert_request.alert_data.get("severity", "unknown"), + "timestamp": datetime.utcnow().isoformat() + }, "new_alerts") + + return AlertResponse( + alert_id=alert_id, + status="processing", + workflow_instance_id=f"{alert_id}_{int(time.time())}", + processing_started=True, + estimated_completion=(datetime.utcnow() + timedelta(minutes=2)).isoformat() + ) + + except Exception as e: + logger.error(f"Alert processing failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/{alert_id}/status", response_model=WorkflowStatusResponse) +async def get_alert_status( + 
alert_id: str, + lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app), +): + """Get current workflow status for an alert.""" + try: + state = await _get_alert_state(alert_id, lg_sotf_app) + + if not state: + raise HTTPException(status_code=404, detail="Alert not found") + + # Extract enriched data + enriched_data = state.get("enriched_data", {}) + + return WorkflowStatusResponse( + alert_id=alert_id, + workflow_instance_id=state.get("workflow_instance_id", ""), + current_node=state.get("current_node", "unknown"), + triage_status=state.get("triage_status", "unknown"), + confidence_score=state.get("confidence_score", 0), + threat_score=state.get("threat_score", 0), + processing_notes=state.get("processing_notes", []), + enriched_data=enriched_data, + escalation_info=enriched_data.get("escalation_info"), + response_execution=state.get("response_execution"), + fp_indicators=state.get("fp_indicators", []), + tp_indicators=state.get("tp_indicators", []), + correlations=state.get("correlations", []), + correlation_score=state.get("correlation_score", 0), + analysis_conclusion=state.get("analysis_conclusion"), + recommended_actions=state.get("recommended_actions", []), + last_updated=state.get("last_updated", datetime.utcnow().isoformat()), + progress_percentage=_calculate_progress(state) + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Status retrieval failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("") +async def get_alerts( + limit: int = 50, + status: Optional[str] = None, + hours: int = 24, + lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app), +): + """Query recent alerts with optional filtering.""" + try: + alerts = await _query_recent_alerts(limit, status, hours, lg_sotf_app) + return alerts + + except Exception as e: + logger.error(f"Alert retrieval failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/{alert_id}/correlations", response_model=CorrelationResponse) +async def get_alert_correlations( + alert_id: str, + lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app), +): + """Get correlations for a specific alert.""" + try: + correlations = await _get_alert_correlations(alert_id, lg_sotf_app) + return correlations + + except Exception as e: + logger.error(f"Correlation retrieval failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# Helper functions + +async def _process_alert_background( + alert_id: str, + alert_data: Dict[str, Any], + lg_sotf_app: LG_SOTFApplication, + ws_manager: WebSocketManager, +): + """Background task to process an alert through the workflow.""" + try: + await ws_manager.broadcast({ + "type": "ingestion_event", + "event": "alert_ingested", + "alert_id": alert_id, + "source": alert_data.get("source", "unknown"), + "severity": alert_data.get("severity", "unknown"), + "timestamp": datetime.utcnow().isoformat() + }, "ingestion_updates") + + await ws_manager.broadcast({ + "type": "alert_update", + "alert_id": alert_id, + "status": "processing", + "progress": 10 + }, "alert_updates") + + result = await lg_sotf_app.process_single_alert(alert_id, alert_data) + + await ws_manager.broadcast({ + "type": "alert_update", + "alert_id": alert_id, + "status": "completed", + "progress": 100, + "result": result + }, "alert_updates") + + except Exception as e: + logger.error(f"Background processing failed: {e}") + + await ws_manager.broadcast({ + "type": "alert_update", + "alert_id": alert_id, + "status": "failed", + "error": str(e) + }, 
"alert_updates") + + +async def _get_alert_state(alert_id: str, lg_sotf_app: LG_SOTFApplication) -> Optional[Dict[str, Any]]: + """Retrieve alert state from PostgreSQL.""" + try: + storage = lg_sotf_app.postgres_storage + + query = """ + SELECT state_data, version, created_at + FROM states + WHERE alert_id = $1 + ORDER BY version DESC + LIMIT 1 + """ + + async with storage.pool.acquire() as conn: + result = await conn.fetchrow(query, alert_id) + + if not result: + return None + + state_json = result['state_data'] + state_data = json.loads(state_json) if isinstance(state_json, str) else state_json + + # Debug logging + logger.info(f"State data for {alert_id}: triage_status={state_data.get('triage_status')}, confidence={state_data.get('confidence_score')}, threat_score={state_data.get('threat_score', 0)}") + + # Extract metadata + metadata = state_data.get("metadata", {}) + + # Handle enum conversion for triage_status + triage_status = state_data.get("triage_status", "unknown") + if isinstance(triage_status, dict) and "_value_" in triage_status: + # Pydantic enum serialization + triage_status = triage_status["_value_"] + elif hasattr(triage_status, 'value'): + # Python enum + triage_status = triage_status.value + else: + # String - clean it up + triage_status = str(triage_status).replace("TriageStatus.", "").replace("_", " ").lower() + + # Build the response with correct field extraction + enriched_data = state_data.get("enriched_data", {}) + llm_insights = enriched_data.get("llm_insights", {}) + + merged_state = { + "alert_id": alert_id, + "workflow_instance_id": state_data.get("workflow_instance_id", ""), + "current_node": state_data.get("current_node", "unknown"), + "triage_status": triage_status, + "confidence_score": int(state_data.get("confidence_score", 0)), + "threat_score": int(state_data.get("threat_score", 0)), + "processing_notes": metadata.get("processing_notes", []), + "enriched_data": enriched_data, + "correlations": state_data.get("correlations", []), + "correlation_score": int(state_data.get("correlation_score", 0)), + "analysis_conclusion": state_data.get("analysis_conclusion", ""), + "recommended_actions": state_data.get("recommended_actions", []), + # Extract FP/TP indicators to top level for easy access + "fp_indicators": state_data.get("fp_indicators", llm_insights.get("fp_indicators", [])), + "tp_indicators": state_data.get("tp_indicators", llm_insights.get("tp_indicators", [])), + # Extract response execution to top level + "response_execution": enriched_data.get("response_execution"), + "last_updated": result['created_at'].isoformat(), + "raw_alert": state_data.get("raw_alert", {}) + } + + logger.info(f"Returning merged state: confidence={merged_state['confidence_score']}, threat_score={merged_state['threat_score']}, status={merged_state['triage_status']}") + + return merged_state + + except Exception as e: + logger.error(f"State retrieval error: {e}", exc_info=True) + return None + + +async def _query_recent_alerts( + limit: int, + status: Optional[str], + hours: int, + lg_sotf_app: LG_SOTFApplication +) -> List[Dict]: + """Query recent alerts from PostgreSQL with filtering.""" + try: + storage = lg_sotf_app.postgres_storage + cutoff_time = datetime.utcnow() - timedelta(hours=hours) + + if status: + query = """ + SELECT DISTINCT ON (alert_id) + alert_id, + state_data->>'workflow_instance_id' as workflow_instance_id, + state_data->>'triage_status' as status, + state_data->>'current_node' as current_node, + (state_data->>'confidence_score')::int as confidence_score, + 
COALESCE((state_data->>'threat_score')::int, 0) as threat_score, + state_data->'raw_alert'->>'severity' as severity, + state_data->'raw_alert'->>'description' as description, + state_data->'raw_alert'->>'title' as title, + created_at + FROM states + WHERE created_at >= $1 + AND state_data->>'triage_status' = $2 + ORDER BY alert_id, version DESC + LIMIT $3 + """ + + async with storage.pool.acquire() as conn: + results = await conn.fetch(query, cutoff_time, status, limit) + else: + query = """ + SELECT DISTINCT ON (alert_id) + alert_id, + state_data->>'workflow_instance_id' as workflow_instance_id, + state_data->>'triage_status' as status, + state_data->>'current_node' as current_node, + (state_data->>'confidence_score')::int as confidence_score, + COALESCE((state_data->>'threat_score')::int, 0) as threat_score, + state_data->'raw_alert'->>'severity' as severity, + state_data->'raw_alert'->>'description' as description, + state_data->'raw_alert'->>'title' as title, + created_at + FROM states + WHERE created_at >= $1 + ORDER BY alert_id, version DESC + LIMIT $2 + """ + + async with storage.pool.acquire() as conn: + results = await conn.fetch(query, cutoff_time, limit) + + alerts = [] + for row in results: + alerts.append({ + "alert_id": row['alert_id'], + "workflow_instance_id": row['workflow_instance_id'], + "status": row['status'], + "current_node": row['current_node'], + "confidence_score": row['confidence_score'], + "threat_score": row['threat_score'], + "severity": row['severity'] or 'medium', + "description": row['description'] or row['title'] or 'Security alert', + "created_at": row['created_at'].isoformat() + }) + + return alerts + + except Exception as e: + logger.error(f"Query recent alerts error: {e}") + return [] + + +async def _get_alert_correlations(alert_id: str, lg_sotf_app: LG_SOTFApplication) -> CorrelationResponse: + """Get correlations for a specific alert with Redis integration.""" + try: + # Get the alert state + state = await _get_alert_state(alert_id, lg_sotf_app) + + if not state: + return CorrelationResponse( + alert_id=alert_id, + correlations=[], + correlation_score=0, + attack_campaign_indicators=[], + threat_actor_patterns=[] + ) + + # Extract correlation data from metadata (workflow-generated correlations) + metadata = state.get("metadata", {}) if isinstance(state, dict) else {} + base_correlations = metadata.get("correlations", []) + + # Get Redis-based real-time correlations + redis_correlations = await _get_redis_correlations(alert_id, lg_sotf_app) + + # Merge both sources (deduplicate by indicator) + all_correlations = base_correlations.copy() + existing_indicators = {c.get('indicator') for c in base_correlations} + + for redis_corr in redis_correlations: + if redis_corr.get('indicator') not in existing_indicators: + all_correlations.append(redis_corr) + + return CorrelationResponse( + alert_id=alert_id, + correlations=all_correlations, + correlation_score=metadata.get("correlation_score", 0), + attack_campaign_indicators=metadata.get("attack_campaign_indicators", []), + threat_actor_patterns=metadata.get("threat_actor_patterns", []) + ) + + except Exception as e: + logger.error(f"Correlation retrieval error: {e}") + return CorrelationResponse( + alert_id=alert_id, + correlations=[], + correlation_score=0, + attack_campaign_indicators=[], + threat_actor_patterns=[] + ) + + +async def _get_redis_correlations(alert_id: str, lg_sotf_app: LG_SOTFApplication) -> List[Dict[str, Any]]: + """Get real-time correlations from Redis.""" + correlations = [] + + try: + 
redis_storage = lg_sotf_app.redis_storage
+        if not redis_storage or not redis_storage.redis_client:
+            return correlations
+
+        # Get all indicators for this alert
+        indicators_key = f"alert:{alert_id}:indicators"
+        indicators_raw = await redis_storage.redis_client.smembers(indicators_key)
+
+        if not indicators_raw:
+            return correlations
+
+        # Decode and parse indicators
+        indicators = []
+        for ind in indicators_raw:
+            ind_str = ind.decode('utf-8') if isinstance(ind, bytes) else str(ind)
+            parts = ind_str.split(":", 1)
+            if len(parts) == 2:
+                indicators.append({
+                    'type': parts[0],
+                    'value': parts[1]
+                })
+
+        # For each indicator, find related alerts and co-occurrences
+        seen_alerts = set()
+
+        for indicator in indicators:
+            indicator_type = indicator['type']
+            indicator_value = indicator['value']
+
+            # 1. Get related alerts sharing this indicator
+            alerts_key = f"indicator:{indicator_type}:{indicator_value}:alerts"
+            related_alerts_raw = await redis_storage.redis_client.smembers(alerts_key)
+
+            related_alerts = []
+            for alert_raw in related_alerts_raw:
+                related_alert_id = alert_raw.decode('utf-8') if isinstance(alert_raw, bytes) else str(alert_raw)
+                if related_alert_id != alert_id and related_alert_id not in seen_alerts:
+                    related_alerts.append(related_alert_id)
+                    seen_alerts.add(related_alert_id)
+
+            if related_alerts:
+                correlations.append({
+                    "type": "shared_indicator",
+                    "indicator": f"{indicator_type}:{indicator_value}",
+                    "indicator_type": indicator_type,
+                    "indicator_value": indicator_value,
+                    "description": f"Indicator {indicator_type}={indicator_value} shared with {len(related_alerts)} other alert(s)",
+                    "confidence": min(90, 50 + len(related_alerts) * 10),
+                    "weight": 30,
+                    "threat_level": "high" if len(related_alerts) > 3 else "medium" if len(related_alerts) > 1 else "low",
+                    "related_alerts": related_alerts[:5],  # Top 5
+                    "total_related": len(related_alerts)
+                })
+
+            # 2. Get co-occurring indicators
+            cooccur_key = f"indicator:{indicator_type}:{indicator_value}:cooccur"
+            cooccur_raw = await redis_storage.redis_client.zrevrange(
+                cooccur_key, 0, 4, withscores=True  # Top 5
+            )
+
+            if cooccur_raw:
+                # With withscores=True the client returns (member, score) pairs
+                cooccurring = []
+                for member, score in cooccur_raw:
+                    member_str = member.decode('utf-8') if isinstance(member, bytes) else str(member)
+                    cooccurring.append({
+                        'indicator': member_str,
+                        'count': int(score)
+                    })
+
+                if cooccurring:
+                    correlations.append({
+                        "type": "co_occurrence",
+                        "indicator": f"{indicator_type}:{indicator_value}",
+                        "description": f"Frequently co-occurs with {len(cooccurring)} other indicator(s)",
+                        "confidence": 70,
+                        "weight": 20,
+                        "threat_level": "medium",
+                        "co_occurring_indicators": cooccurring
+                    })
+
+            # 3.
Get indicator metadata (count, first/last seen) + metadata_key = f"indicator:{indicator_type}:{indicator_value}" + metadata_raw = await redis_storage.redis_client.hgetall(metadata_key) + + if metadata_raw: + metadata = {} + for key, value in metadata_raw.items(): + key_str = key.decode('utf-8') if isinstance(key, bytes) else str(key) + value_str = value.decode('utf-8') if isinstance(value, bytes) else str(value) + metadata[key_str] = value_str + + count = int(metadata.get('count', 0)) + if count > 1: + correlations.append({ + "type": "frequency", + "indicator": f"{indicator_type}:{indicator_value}", + "description": f"Seen {count} times (first: {metadata.get('first_seen', 'unknown')}, last: {metadata.get('last_seen', 'unknown')})", + "confidence": min(80, 40 + count * 10), + "weight": 15, + "threat_level": "high" if count > 5 else "medium" if count > 2 else "low", + "frequency_count": count, + "first_seen": metadata.get('first_seen'), + "last_seen": metadata.get('last_seen') + }) + + return correlations + + except Exception as e: + logger.error(f"Redis correlation retrieval error: {e}", exc_info=True) + return [] + + +def _calculate_progress(state: Dict[str, Any]) -> int: + """Calculate workflow progress percentage based on current node.""" + node_progress = { + "ingestion": 10, + "triage": 30, + "correlation": 50, + "analysis": 70, + "human_loop": 85, + "response": 95, + "close": 100 + } + + current_node = state.get("current_node", "ingestion") + return node_progress.get(current_node, 0) diff --git a/src/lg_sotf/api/routers/correlations.py b/src/lg_sotf/api/routers/correlations.py new file mode 100644 index 00000000..0ad922e4 --- /dev/null +++ b/src/lg_sotf/api/routers/correlations.py @@ -0,0 +1,252 @@ +"""Correlation metrics and network analysis endpoints.""" + +import json +import logging +from datetime import datetime +from typing import Any, Dict + +from fastapi import APIRouter, Depends, HTTPException + +from lg_sotf.app_initializer import LG_SOTFApplication +from lg_sotf.api.dependencies import get_lg_sotf_app + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/correlations", tags=["correlations"]) + + +@router.get("/metrics") +async def get_correlation_metrics( + lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app) +): + """Get real-time correlation metrics from Redis.""" + try: + metrics = await _get_correlation_metrics(lg_sotf_app) + return metrics + + except Exception as e: + logger.error(f"Correlation metrics retrieval failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/network") +async def get_correlation_network( + limit: int = 50, + lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app) +): + """Get correlation network graph data showing alert relationships.""" + try: + network_data = await _get_correlation_network(limit, lg_sotf_app) + return network_data + + except Exception as e: + logger.error(f"Correlation network retrieval failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +# Helper functions + +async def _get_correlation_metrics(lg_sotf_app: LG_SOTFApplication) -> Dict[str, Any]: + """Get comprehensive correlation metrics from Redis.""" + try: + redis_storage = lg_sotf_app.redis_storage + if not redis_storage or not redis_storage.redis_client: + return { + "total_indicators": 0, + "total_alerts": 0, + "correlation_patterns": {}, + "timestamp": datetime.utcnow().isoformat() + } + + # Count total unique indicators + indicator_count = 0 + alert_ids = 
set() + indicator_types = {} + + async for key in redis_storage.redis_client.scan_iter(match="indicator:*:alerts"): + indicator_count += 1 + key_str = key.decode('utf-8') if isinstance(key, bytes) else str(key) + + # Extract indicator type + parts = key_str.split(":") + if len(parts) >= 2: + indicator_type = parts[1] + indicator_types[indicator_type] = indicator_types.get(indicator_type, 0) + 1 + + # Get alerts for this indicator + alerts_raw = await redis_storage.redis_client.smembers(key) + for alert in alerts_raw: + alert_str = alert.decode('utf-8') if isinstance(alert, bytes) else str(alert) + alert_ids.add(alert_str) + + # Count total unique alerts + total_alerts = len(alert_ids) + + # Calculate correlation patterns + shared_indicators = 0 + async for key in redis_storage.redis_client.scan_iter(match="indicator:*:alerts"): + count = await redis_storage.redis_client.scard(key) + if count > 1: + shared_indicators += 1 + + return { + "total_indicators": indicator_count, + "total_alerts": total_alerts, + "shared_indicators": shared_indicators, + "correlation_rate": round((shared_indicators / indicator_count * 100) if indicator_count > 0 else 0, 2), + "indicator_types": indicator_types, + "avg_indicators_per_alert": round(indicator_count / total_alerts, 2) if total_alerts > 0 else 0, + "timestamp": datetime.utcnow().isoformat() + } + + except Exception as e: + logger.error(f"Correlation metrics error: {e}", exc_info=True) + return { + "total_indicators": 0, + "total_alerts": 0, + "correlation_patterns": {}, + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + } + + +async def _get_correlation_network(limit: int, lg_sotf_app: LG_SOTFApplication) -> Dict[str, Any]: + """Build correlation network graph showing alert relationships.""" + try: + redis_storage = lg_sotf_app.redis_storage + if not redis_storage or not redis_storage.redis_client: + return {"nodes": [], "edges": []} + + nodes = [] # Alerts + edges = [] # Relationships via shared indicators + alert_indicators = {} # Track which indicators each alert has + + # Get all alerts + alert_keys = [] + async for key in redis_storage.redis_client.scan_iter(match="alert:*:indicators"): + key_str = key.decode('utf-8') if isinstance(key, bytes) else str(key) + alert_id = key_str.split(":")[1] + alert_keys.append((alert_id, key_str)) + + # Limit to most recent alerts + alert_keys = alert_keys[:limit] + + # Build nodes and collect indicators + for alert_id, key in alert_keys: + # Get indicators for this alert + indicators_raw = await redis_storage.redis_client.smembers(key) + indicators = [] + for ind in indicators_raw: + ind_str = ind.decode('utf-8') if isinstance(ind, bytes) else str(ind) + indicators.append(ind_str) + + alert_indicators[alert_id] = set(indicators) + + # Get alert metadata from PostgreSQL if available + alert_state = await _get_alert_state(alert_id, lg_sotf_app) + severity = "unknown" + status = "unknown" + + if alert_state: + raw_alert = alert_state.get("raw_alert", {}) if isinstance(alert_state, dict) else {} + if isinstance(raw_alert, dict): + severity = raw_alert.get("severity", "unknown") + status = alert_state.get("triage_status", "unknown") + + nodes.append({ + "id": alert_id, + "type": "alert", + "label": alert_id, + "severity": severity, + "status": status, + "indicator_count": len(indicators) + }) + + # Build edges (connections via shared indicators) + edge_id = 0 + processed_pairs = set() + + for alert1_id, indicators1 in alert_indicators.items(): + for alert2_id, indicators2 in alert_indicators.items(): + 
if alert1_id >= alert2_id: # Skip self and duplicates + continue + + # Check if already processed + pair = tuple(sorted([alert1_id, alert2_id])) + if pair in processed_pairs: + continue + + # Find shared indicators + shared = indicators1.intersection(indicators2) + + if shared: + processed_pairs.add(pair) + edges.append({ + "id": f"edge_{edge_id}", + "source": alert1_id, + "target": alert2_id, + "shared_indicators": list(shared), + "shared_count": len(shared), + "weight": len(shared) # For graph visualization + }) + edge_id += 1 + + return { + "nodes": nodes, + "edges": edges, + "summary": { + "total_alerts": len(nodes), + "total_connections": len(edges), + "timestamp": datetime.utcnow().isoformat() + } + } + + except Exception as e: + logger.error(f"Correlation network error: {e}", exc_info=True) + return { + "nodes": [], + "edges": [], + "error": str(e) + } + + +async def _get_alert_state(alert_id: str, lg_sotf_app: LG_SOTFApplication) -> Dict[str, Any]: + """Retrieve alert state from PostgreSQL.""" + try: + storage = lg_sotf_app.postgres_storage + + query = """ + SELECT state_data, version, created_at + FROM states + WHERE alert_id = $1 + ORDER BY version DESC + LIMIT 1 + """ + + async with storage.pool.acquire() as conn: + result = await conn.fetchrow(query, alert_id) + + if not result: + return {} + + state_json = result['state_data'] + state_data = json.loads(state_json) if isinstance(state_json, str) else state_json + + # Handle enum conversion for triage_status + triage_status = state_data.get("triage_status", "unknown") + if isinstance(triage_status, dict) and "_value_" in triage_status: + triage_status = triage_status["_value_"] + elif hasattr(triage_status, 'value'): + triage_status = triage_status.value + else: + triage_status = str(triage_status).replace("TriageStatus.", "").replace("_", " ").lower() + + return { + "alert_id": alert_id, + "triage_status": triage_status, + "raw_alert": state_data.get("raw_alert", {}) + } + + except Exception as e: + logger.error(f"Alert state retrieval error: {e}") + return {} diff --git a/src/lg_sotf/api/routers/dashboard.py b/src/lg_sotf/api/routers/dashboard.py new file mode 100644 index 00000000..21475c9d --- /dev/null +++ b/src/lg_sotf/api/routers/dashboard.py @@ -0,0 +1,191 @@ +"""Dashboard statistics and analytics endpoints.""" + +import logging +from datetime import datetime, timedelta +from typing import Any, Dict, List + +from fastapi import APIRouter, Depends, HTTPException + +from lg_sotf.app_initializer import LG_SOTFApplication +from lg_sotf.api.dependencies import get_lg_sotf_app +from lg_sotf.api.models.metrics import DashboardStatsResponse + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/dashboard", tags=["dashboard"]) + + +@router.get("/stats", response_model=DashboardStatsResponse) +async def get_dashboard_stats( + lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app) +): + """Get comprehensive dashboard statistics.""" + try: + stats = await _get_dashboard_statistics(lg_sotf_app) + return stats + + except Exception as e: + logger.error(f"Dashboard stats retrieval failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# Helper functions + +async def _get_dashboard_statistics(lg_sotf_app: LG_SOTFApplication) -> DashboardStatsResponse: + """Collect dashboard statistics from PostgreSQL and Redis.""" + try: + storage = lg_sotf_app.postgres_storage + cutoff_time = datetime.utcnow() - timedelta(hours=24) + + async with storage.pool.acquire() as conn: + total_alerts = await 
conn.fetchval(""" + SELECT COUNT(DISTINCT alert_id) + FROM states + WHERE created_at >= $1 + """, cutoff_time) + + high_priority = await conn.fetchval(""" + SELECT COUNT(DISTINCT s1.alert_id) + FROM states s1 + INNER JOIN ( + SELECT alert_id, MAX(version) as max_version + FROM states + WHERE created_at >= $1 + GROUP BY alert_id + ) s2 ON s1.alert_id = s2.alert_id AND s1.version = s2.max_version + WHERE (state_data->>'priority_level')::int <= 2 + """, cutoff_time) + + status_results = await conn.fetch(""" + SELECT + state_data->>'triage_status' as status, + COUNT(DISTINCT s1.alert_id) as count + FROM states s1 + INNER JOIN ( + SELECT alert_id, MAX(version) as max_version + FROM states + WHERE created_at >= $1 + GROUP BY alert_id + ) s2 ON s1.alert_id = s2.alert_id AND s1.version = s2.max_version + GROUP BY state_data->>'triage_status' + """, cutoff_time) + + severity_results = await conn.fetch(""" + SELECT + state_data->'raw_alert'->>'severity' as severity, + COUNT(DISTINCT alert_id) as count + FROM states + WHERE created_at >= $1 + GROUP BY state_data->'raw_alert'->>'severity' + """, cutoff_time) + + alerts_by_status = {row['status']: row['count'] for row in status_results if row['status']} + alerts_by_severity = {row['severity']: row['count'] for row in severity_results if row['severity']} + + # Get top threat indicators from Redis + top_threat_indicators = await _get_top_threat_indicators(lg_sotf_app) + + return DashboardStatsResponse( + total_alerts_today=total_alerts or 0, + high_priority_alerts=high_priority or 0, + alerts_by_status=alerts_by_status, + alerts_by_severity=alerts_by_severity, + top_threat_indicators=top_threat_indicators, + recent_escalations=[], + processing_time_avg=125.0 + ) + + except Exception as e: + logger.error(f"Dashboard statistics error: {e}") + return DashboardStatsResponse( + total_alerts_today=0, + high_priority_alerts=0, + alerts_by_status={}, + alerts_by_severity={}, + top_threat_indicators=[], + recent_escalations=[], + processing_time_avg=0.0 + ) + + +async def _get_top_threat_indicators(lg_sotf_app: LG_SOTFApplication, limit: int = 10) -> List[Dict[str, Any]]: + """Get top threat indicators from Redis based on frequency and correlation.""" + try: + redis_storage = lg_sotf_app.redis_storage + if not redis_storage or not redis_storage.redis_client: + return [] + + indicators_data = [] + + # Scan all indicator alert sets + async for key in redis_storage.redis_client.scan_iter(match="indicator:*:alerts"): + key_str = key.decode('utf-8') if isinstance(key, bytes) else str(key) + + # Extract indicator type and value from key + # Format: indicator:{type}:{value}:alerts + parts = key_str.split(":") + if len(parts) >= 3: + indicator_type = parts[1] + indicator_value = ":".join(parts[2:-1]) # Handle values with colons + + # Get count of alerts with this indicator + alert_count = await redis_storage.redis_client.scard(key) + + if alert_count > 0: + # Get metadata for additional context + metadata_key = f"indicator:{indicator_type}:{indicator_value}" + metadata_raw = await redis_storage.redis_client.hgetall(metadata_key) + + metadata = {} + if metadata_raw: + for k, v in metadata_raw.items(): + k_str = k.decode('utf-8') if isinstance(k, bytes) else str(k) + v_str = v.decode('utf-8') if isinstance(v, bytes) else str(v) + metadata[k_str] = v_str + + # Calculate threat score based on frequency and recency + frequency_count = int(metadata.get('count', alert_count)) + threat_score = alert_count * 10 # Base score from alert count + + # Boost score for high-risk 
indicator types + risk_multipliers = { + 'file_hash': 2.0, + 'destination_ip': 1.5, + 'user': 1.3, + 'username': 1.3, + 'source_ip': 1.2 + } + threat_score *= risk_multipliers.get(indicator_type, 1.0) + + indicators_data.append({ + 'indicator': f"{indicator_type}:{indicator_value}", + 'indicator_type': indicator_type, + 'indicator_value': indicator_value, + 'count': alert_count, + 'frequency': frequency_count, + 'threat_score': int(threat_score), + 'first_seen': metadata.get('first_seen'), + 'last_seen': metadata.get('last_seen') + }) + + # Sort by threat score descending + indicators_data.sort(key=lambda x: x['threat_score'], reverse=True) + + # Return top N with formatted output + top_indicators = [] + for ind in indicators_data[:limit]: + top_indicators.append({ + 'indicator': ind['indicator'], + 'indicator_type': ind['indicator_type'], + 'indicator_value': ind['indicator_value'], + 'count': ind['count'], + 'threat_level': 'high' if ind['threat_score'] > 50 else 'medium' if ind['threat_score'] > 20 else 'low', + 'first_seen': ind['first_seen'], + 'last_seen': ind['last_seen'] + }) + + return top_indicators + + except Exception as e: + logger.error(f"Top threat indicators error: {e}", exc_info=True) + return [] diff --git a/src/lg_sotf/api/routers/escalations.py b/src/lg_sotf/api/routers/escalations.py new file mode 100644 index 00000000..068843a5 --- /dev/null +++ b/src/lg_sotf/api/routers/escalations.py @@ -0,0 +1,129 @@ +"""Escalation and human-in-the-loop endpoints.""" + +import logging +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException + +from lg_sotf.app_initializer import LG_SOTFApplication +from lg_sotf.api.dependencies import get_lg_sotf_app +from lg_sotf.api.models.workflows import FeedbackRequest + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/escalations", tags=["escalations"]) + + +@router.get("") +async def get_pending_escalations( + level: Optional[str] = None, + limit: int = 50, + lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app) +): + """Get pending escalations from queue.""" + try: + human_loop_agent = lg_sotf_app.workflow_engine.agents.get("human_loop") + if not human_loop_agent: + raise HTTPException(status_code=503, detail="Human loop agent not available") + + escalations = await human_loop_agent.get_pending_escalations(level=level, limit=limit) + return {"escalations": escalations, "count": len(escalations)} + + except HTTPException: + raise + except Exception as e: + logger.error(f"Escalation retrieval failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/{escalation_id}/assign") +async def assign_escalation( + escalation_id: str, + analyst_username: str, + lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app) +): + """Assign escalation to analyst.""" + try: + human_loop_agent = lg_sotf_app.workflow_engine.agents.get("human_loop") + if not human_loop_agent: + raise HTTPException(status_code=503, detail="Human loop agent not available") + + result = await human_loop_agent.assign_escalation( + escalation_id=escalation_id, + analyst_username=analyst_username + ) + return result + + except HTTPException: + raise + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + except Exception as e: + logger.error(f"Escalation assignment failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/{escalation_id}/feedback") +async def submit_feedback( + escalation_id: str, + feedback: 
FeedbackRequest, + lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app) +): + """Submit analyst feedback for escalation.""" + try: + human_loop_agent = lg_sotf_app.workflow_engine.agents.get("human_loop") + if not human_loop_agent: + raise HTTPException(status_code=503, detail="Human loop agent not available") + + result = await human_loop_agent.submit_feedback( + escalation_id=escalation_id, + analyst_username=feedback.analyst_username, + decision=feedback.decision, + confidence=feedback.confidence, + notes=feedback.notes, + actions_taken=feedback.actions_taken, + actions_recommended=feedback.actions_recommended, + triage_correct=feedback.triage_correct, + correlation_helpful=feedback.correlation_helpful, + analysis_accurate=feedback.analysis_accurate + ) + return result + + except HTTPException: + raise + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + except Exception as e: + logger.error(f"Feedback submission failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/stats") +async def get_escalation_stats( + lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app) +): + """Get escalation queue statistics.""" + try: + human_loop_agent = lg_sotf_app.workflow_engine.agents.get("human_loop") + if not human_loop_agent: + raise HTTPException(status_code=503, detail="Human loop agent not available") + + # Get queue stats from PostgreSQL + queue_stats = await human_loop_agent.get_queue_stats() + + # Try to get decision stats and accuracy, but don't fail if not available + try: + decision_stats = await human_loop_agent.get_decision_stats() + triage_accuracy = await human_loop_agent.get_triage_accuracy() + if triage_accuracy and 'accuracy_rate' in triage_accuracy: + queue_stats['accuracy_rate'] = triage_accuracy['accuracy_rate'] + except Exception as e: + logger.warning(f"Could not get decision stats: {e}") + + return queue_stats + + except HTTPException: + raise + except Exception as e: + logger.error(f"Escalation stats retrieval failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/lg_sotf/api/routers/ingestion.py b/src/lg_sotf/api/routers/ingestion.py new file mode 100644 index 00000000..5843d56a --- /dev/null +++ b/src/lg_sotf/api/routers/ingestion.py @@ -0,0 +1,233 @@ +"""Ingestion control and monitoring endpoints.""" + +import asyncio +import logging +from datetime import datetime, timedelta + +from fastapi import APIRouter, Depends, HTTPException + +from lg_sotf.app_initializer import LG_SOTFApplication +from lg_sotf.api.dependencies import get_lg_sotf_app, get_websocket_manager +from lg_sotf.api.models.ingestion import ( + IngestionStatusResponse, + IngestionControlRequest, +) +from lg_sotf.api.utils.websocket import WebSocketManager + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/ingestion", tags=["ingestion"]) + + +@router.get("/status", response_model=IngestionStatusResponse) +async def get_ingestion_status( + lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app) +): + """Get current ingestion status and metrics.""" + try: + if not lg_sotf_app.workflow_engine or "ingestion" not in lg_sotf_app.workflow_engine.agents: + raise HTTPException(status_code=503, detail="Ingestion agent not available") + + ingestion_agent = ( + lg_sotf_app.workflow_engine.agents.get("ingestion_instance") or + lg_sotf_app.workflow_engine.agents.get("ingestion") + ) + ingestion_config = lg_sotf_app.config_manager.get_agent_config("ingestion") + + # 
Calculate next poll time + next_poll = None + if lg_sotf_app._last_ingestion_poll: + polling_interval = ingestion_config.get("polling_interval", 60) + next_poll = (lg_sotf_app._last_ingestion_poll + timedelta(seconds=polling_interval)).isoformat() + + # Check if ingestion is active (agent is initialized and has sources) + # Note: We check initialized + enabled sources instead of self.running + # because the app can run in API-only mode without the continuous loop + is_active = ingestion_agent.initialized and len(ingestion_agent.enabled_sources) > 0 + + return IngestionStatusResponse( + is_active=is_active, + last_poll_time=lg_sotf_app._last_ingestion_poll.isoformat() if lg_sotf_app._last_ingestion_poll else None, + next_poll_time=next_poll, + polling_interval=ingestion_config.get("polling_interval", 60), + sources_enabled=ingestion_agent.enabled_sources, + sources_stats=ingestion_agent.get_source_stats()["by_source"], + total_ingested=ingestion_agent.ingestion_stats["total_ingested"], + total_deduplicated=ingestion_agent.ingestion_stats["total_deduplicated"], + total_errors=ingestion_agent.ingestion_stats["total_errors"] + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Ingestion status retrieval failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/control") +async def control_ingestion( + request: IngestionControlRequest, + lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app), + ws_manager: WebSocketManager = Depends(get_websocket_manager), +): + """Control ingestion process (trigger poll, etc).""" + try: + ingestion_agent = ( + lg_sotf_app.workflow_engine.agents.get("ingestion_instance") or + lg_sotf_app.workflow_engine.agents.get("ingestion") + ) + + if not ingestion_agent: + raise HTTPException(status_code=503, detail="Ingestion agent not available") + + if request.action == "trigger_poll": + # Actually poll - don't just reset the timer + logger.info("Manual ingestion poll triggered") + + try: + # Call poll_sources directly + new_alerts = await ingestion_agent.poll_sources() + + # Update last poll timestamp + lg_sotf_app._last_ingestion_poll = datetime.utcnow() + + logger.info(f"Manual poll found {len(new_alerts)} alerts") + + # Broadcast ingestion triggered event + await ws_manager.broadcast({ + "type": "ingestion_triggered", + "timestamp": datetime.utcnow().isoformat(), + "sources": request.sources or ingestion_agent.enabled_sources, + "alerts_found": len(new_alerts) + }, "ingestion_updates") + + # Process each alert through workflow + # NOTE: Manual trigger processes alerts immediately. If automatic polling + # is also running, the same alerts might be processed twice. This is by design + # for now - deduplication happens at the ingestion level, but alerts can still + # be processed multiple times if they arrive close together. + # The workflow engine should handle this gracefully via state versioning. 
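+            # Sketch of one possible guard (illustrative only, not wired in
+            # here), assuming redis-py's async client as used elsewhere in
+            # this codebase; the claim-key layout is hypothetical:
+            #
+            #     claimed = await lg_sotf_app.redis_storage.redis_client.set(
+            #         f"alert:{alert['id']}:claim", "1", nx=True, ex=300
+            #     )
+            #     if not claimed:
+            #         continue  # another poll cycle already owns this alert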
+ for alert in new_alerts: + try: + # Broadcast that alert was ingested + await ws_manager.broadcast({ + "type": "new_alert", + "alert_id": alert["id"], + "severity": alert.get("severity", "unknown"), + "source": alert.get("source", "unknown"), + "timestamp": datetime.utcnow().isoformat() + }, "new_alerts") + + # Process alert in background + asyncio.create_task( + _process_alert_background(alert["id"], alert, lg_sotf_app, ws_manager) + ) + + except Exception as e: + logger.error(f"Error processing alert {alert.get('id')}: {e}") + + return { + "status": "success", + "message": f"Ingestion poll completed - found {len(new_alerts)} alerts", + "alerts_found": len(new_alerts), + "timestamp": datetime.utcnow().isoformat() + } + + except Exception as e: + logger.error(f"Manual poll failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Poll failed: {str(e)}") + + elif request.action == "get_stats": + return ingestion_agent.get_source_stats() + + else: + raise HTTPException(status_code=400, detail=f"Unknown action: {request.action}") + + except HTTPException: + raise + except Exception as e: + logger.error(f"Ingestion control failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/sources") +async def get_ingestion_sources( + lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app) +): + """Get all configured ingestion sources and their status.""" + try: + if not lg_sotf_app.workflow_engine or "ingestion" not in lg_sotf_app.workflow_engine.agents: + return {"sources": []} + + ingestion_agent = ( + lg_sotf_app.workflow_engine.agents.get("ingestion_instance") or + lg_sotf_app.workflow_engine.agents.get("ingestion") + ) + sources_info = [] + + for source_name, plugin in ingestion_agent.plugins.items(): + try: + is_healthy = await plugin.health_check() + metrics = plugin.get_metrics() + + sources_info.append({ + "name": source_name, + "enabled": plugin.enabled, + "healthy": is_healthy, + "initialized": plugin.initialized, + "fetch_count": metrics["fetch_count"], + "error_count": metrics["error_count"], + "last_fetch": metrics["last_fetch_time"] + }) + except Exception as e: + logger.warning(f"Error getting info for source {source_name}: {e}") + sources_info.append({ + "name": source_name, + "enabled": False, + "healthy": False, + "error": str(e) + }) + + return {"sources": sources_info} + + except Exception as e: + logger.error(f"Sources retrieval failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +# Helper functions + +async def _process_alert_background( + alert_id: str, + alert_data: dict, + lg_sotf_app: LG_SOTFApplication, + ws_manager: WebSocketManager, +): + """Process an alert through the workflow.""" + try: + await ws_manager.broadcast({ + "type": "alert_update", + "alert_id": alert_id, + "status": "processing", + "progress": 10 + }, "alert_updates") + + result = await lg_sotf_app.process_single_alert(alert_id, alert_data) + + await ws_manager.broadcast({ + "type": "alert_update", + "alert_id": alert_id, + "status": "completed", + "progress": 100, + "result": result + }, "alert_updates") + + except Exception as e: + logger.error(f"Background processing failed for {alert_id}: {e}") + + await ws_manager.broadcast({ + "type": "alert_update", + "alert_id": alert_id, + "status": "failed", + "error": str(e) + }, "alert_updates") diff --git a/src/lg_sotf/api/routers/metrics.py b/src/lg_sotf/api/routers/metrics.py new file mode 100644 index 00000000..d1cc0061 --- /dev/null +++ 
b/src/lg_sotf/api/routers/metrics.py
@@ -0,0 +1,201 @@
+"""Health and metrics endpoints."""
+
+import logging
+from datetime import datetime, timedelta
+from typing import Dict, List
+
+from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi.responses import JSONResponse
+
+from lg_sotf.app_initializer import LG_SOTFApplication
+from lg_sotf.agents.registry import agent_registry
+from lg_sotf.api.dependencies import get_lg_sotf_app
+from lg_sotf.api.models.metrics import (
+    MetricsResponse,
+    AgentStatusResponse,
+)
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/api/v1", tags=["metrics"])
+
+
+@router.get("/health")
+async def health_check(
+    request: Request,
+    lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app)
+):
+    """Health check endpoint."""
+    try:
+        app_health = await lg_sotf_app.health_check()
+
+        # Get WebSocket manager from app state via the injected request
+        ws_manager = getattr(request.app.state, 'ws_manager', None)
+        ws_connections = len(ws_manager.active_connections) if ws_manager else 0
+
+        return {
+            "status": "healthy" if app_health else "unhealthy",
+            "timestamp": datetime.utcnow().isoformat(),
+            "version": "1.0.0",
+            "components": {
+                "lg_sotf_app": app_health,
+                "websocket_connections": ws_connections,
+                "api": True
+            }
+        }
+    except Exception as e:
+        return JSONResponse(
+            status_code=503,
+            content={
+                "status": "unhealthy",
+                "error": str(e),
+                "timestamp": datetime.utcnow().isoformat()
+            }
+        )
+
+
+@router.get("/metrics", response_model=MetricsResponse)
+async def get_system_metrics(
+    lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app)
+):
+    """Get system-wide metrics."""
+    try:
+        return await _collect_system_metrics(lg_sotf_app)
+    except Exception as e:
+        logger.error(f"Metrics collection failed: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.get("/agents/status", response_model=List[AgentStatusResponse])
+async def get_agents_status(
+    lg_sotf_app: LG_SOTFApplication = Depends(get_lg_sotf_app)
+):
+    """Get status of all registered agents."""
+    try:
+        agents_status = []
+        stats = agent_registry.get_registry_stats()
+
+        for agent_name in stats.get("agent_instances", []):
+            try:
+                agent = agent_registry.get_agent(agent_name)
+                metrics = agent.get_metrics()
+                is_healthy = await agent.health_check() if hasattr(agent, 'health_check') else agent.initialized
+
+                agents_status.append(AgentStatusResponse(
+                    agent_name=agent_name,
+                    status="healthy" if is_healthy else "unhealthy",
+                    last_execution=metrics.get("last_execution"),
+                    success_rate=1.0 - metrics.get("error_rate", 0),
+                    average_execution_time=metrics.get("avg_execution_time", 0),
+                    error_count=metrics.get("error_count", 0)
+                ))
+            except Exception as e:
+                logger.warning(f"Agent {agent_name} status error: {e}")
+                agents_status.append(AgentStatusResponse(
+                    agent_name=agent_name,
+                    status="error",
+                    last_execution=None,
+                    success_rate=0.0,
+                    average_execution_time=0.0,
+                    error_count=1
+                ))
+
+        return agents_status
+
+    except Exception as e:
+        logger.error(f"Agent status retrieval failed: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+# Helper functions
+
+async def _collect_system_metrics(lg_sotf_app: LG_SOTFApplication) -> MetricsResponse:
+    """Collect comprehensive system metrics from PostgreSQL."""
+    try:
+        app_status = lg_sotf_app.get_application_status()
+        storage = lg_sotf_app.postgres_storage
+
+        cutoff_time = datetime.utcnow() - timedelta(hours=24)
+
+        async with 
storage.pool.acquire() as conn: + alerts_today = await conn.fetchval(""" + SELECT COUNT(DISTINCT alert_id) + FROM states + WHERE created_at >= $1 + """, cutoff_time) + + alerts_in_progress = await conn.fetchval(""" + SELECT COUNT(DISTINCT s1.alert_id) + FROM states s1 + INNER JOIN ( + SELECT alert_id, MAX(version) as max_version + FROM states + GROUP BY alert_id + ) s2 ON s1.alert_id = s2.alert_id AND s1.version = s2.max_version + WHERE state_data->>'triage_status' IN ('processing', 'triaged', 'correlated', 'analyzed') + """) + + # Calculate average processing time (from first state to last state) + avg_processing_time = await conn.fetchval(""" + WITH alert_times AS ( + SELECT + alert_id, + MIN(created_at) as first_state, + MAX(created_at) as last_state + FROM states + WHERE created_at >= $1 + GROUP BY alert_id + ) + SELECT AVG(EXTRACT(EPOCH FROM (last_state - first_state))) + FROM alert_times + WHERE EXTRACT(EPOCH FROM (last_state - first_state)) > 0 + """, cutoff_time) + + # Calculate success rate (closed or responded / total) + total_completed = await conn.fetchval(""" + SELECT COUNT(DISTINCT s1.alert_id) + FROM states s1 + INNER JOIN ( + SELECT alert_id, MAX(version) as max_version + FROM states + WHERE created_at >= $1 + GROUP BY alert_id + ) s2 ON s1.alert_id = s2.alert_id AND s1.version = s2.max_version + WHERE state_data->>'triage_status' IN ('closed', 'responded') + """, cutoff_time) + + success_rate = (total_completed / alerts_today) if alerts_today and alerts_today > 0 else 0.0 + + agent_health = {} + stats = agent_registry.get_registry_stats() + for agent_name in stats.get("agent_instances", []): + try: + agent = agent_registry.get_agent(agent_name) + agent_health[agent_name] = await agent.health_check() if hasattr(agent, 'health_check') else agent.initialized + except Exception: + agent_health[agent_name] = False + + system_health = app_status.get("running", False) and app_status.get("initialized", False) + + return MetricsResponse( + timestamp=datetime.utcnow().isoformat(), + alerts_processed_today=alerts_today or 0, + alerts_in_progress=alerts_in_progress or 0, + average_processing_time=float(avg_processing_time) if avg_processing_time else 0.0, + success_rate=float(success_rate), + agent_health=agent_health, + system_health=system_health + ) + + except Exception as e: + logger.error(f"Metrics collection error: {e}") + return MetricsResponse( + timestamp=datetime.utcnow().isoformat(), + alerts_processed_today=0, + alerts_in_progress=0, + average_processing_time=0.0, + success_rate=0.0, + agent_health={}, + system_health=False + ) diff --git a/src/lg_sotf/api/routers/websocket.py b/src/lg_sotf/api/routers/websocket.py new file mode 100644 index 00000000..8c36f7a2 --- /dev/null +++ b/src/lg_sotf/api/routers/websocket.py @@ -0,0 +1,47 @@ +"""WebSocket endpoint for real-time updates.""" + +import json +import logging +from datetime import datetime + +from fastapi import APIRouter, WebSocket, WebSocketDisconnect + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["websocket"]) + + +@router.websocket("/ws/{client_id}") +async def websocket_endpoint( + websocket: WebSocket, + client_id: str +): + """WebSocket endpoint for real-time updates.""" + # Get WebSocket manager from app state + ws_manager = websocket.app.state.ws_manager + + await ws_manager.connect(websocket, client_id) + + try: + while True: + data = await websocket.receive_text() + message = json.loads(data) + + if message.get("type") == "subscribe": + subscriptions = message.get("subscriptions", []) + 
ws_manager.client_subscriptions[client_id] = subscriptions + + await ws_manager.send_personal_message({ + "type": "subscription_confirmed", + "subscriptions": subscriptions + }, client_id) + + elif message.get("type") == "ping": + await ws_manager.send_personal_message({ + "type": "pong", + "timestamp": datetime.utcnow().isoformat() + }, client_id) + + except WebSocketDisconnect: + ws_manager.disconnect(client_id) + logger.info(f"Client {client_id} disconnected") diff --git a/src/lg_sotf/api/utils/__init__.py b/src/lg_sotf/api/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/lg_sotf/api/utils/data_processing.py b/src/lg_sotf/api/utils/data_processing.py new file mode 100644 index 00000000..e69de29b diff --git a/src/lg_sotf/api/utils/websocket.py b/src/lg_sotf/api/utils/websocket.py new file mode 100644 index 00000000..d3050dec --- /dev/null +++ b/src/lg_sotf/api/utils/websocket.py @@ -0,0 +1,91 @@ +"""WebSocket connection manager.""" + +import asyncio +import json +import logging +from datetime import datetime +from typing import Dict, List + +from fastapi import WebSocket + + +class WebSocketManager: + """Manages WebSocket connections and message broadcasting.""" + + def __init__(self): + self.active_connections: Dict[str, WebSocket] = {} + self.client_subscriptions: Dict[str, List[str]] = {} + self.heartbeat_task = None + self.logger = logging.getLogger(__name__) + + async def connect(self, websocket: WebSocket, client_id: str): + """Accept and register a new WebSocket connection.""" + await websocket.accept() + self.active_connections[client_id] = websocket + self.client_subscriptions[client_id] = [] + + await self.send_personal_message({ + "type": "connection", + "status": "connected", + "client_id": client_id, + "server_time": datetime.utcnow().isoformat() + }, client_id) + + def disconnect(self, client_id: str): + """Disconnect and unregister a WebSocket connection.""" + if client_id in self.active_connections: + del self.active_connections[client_id] + if client_id in self.client_subscriptions: + del self.client_subscriptions[client_id] + + async def send_personal_message(self, message: dict, client_id: str): + """Send a message to a specific client.""" + if client_id not in self.active_connections: + return + + try: + await self.active_connections[client_id].send_text( + json.dumps(message, default=str) + ) + except Exception as e: + self.logger.warning(f"Failed to send to {client_id}: {e}") + self.disconnect(client_id) + + async def broadcast(self, message: dict, subscription_type: str = None): + """Broadcast a message to all connected clients.""" + disconnected = [] + + for client_id, websocket in self.active_connections.items(): + if subscription_type: + subscriptions = self.client_subscriptions.get(client_id, []) + if subscription_type not in subscriptions: + continue + + try: + await websocket.send_text(json.dumps(message, default=str)) + except Exception as e: + self.logger.warning(f"Broadcast error to {client_id}: {e}") + disconnected.append(client_id) + + for client_id in disconnected: + self.disconnect(client_id) + + async def heartbeat_loop(self): + """Send periodic heartbeat messages to all clients.""" + while True: + try: + await asyncio.sleep(30) + + message = { + "type": "heartbeat", + "timestamp": datetime.utcnow().isoformat(), + "active_connections": len(self.active_connections) + } + + await self.broadcast(message) + + except asyncio.CancelledError: + self.logger.info("Heartbeat loop cancelled") + break + except Exception as e: + 
self.logger.error(f"Heartbeat error: {e}") diff --git a/src/lg_sotf/app_initializer.py b/src/lg_sotf/app_initializer.py new file mode 100644 index 00000000..cdd562f8 --- /dev/null +++ b/src/lg_sotf/app_initializer.py @@ -0,0 +1,689 @@ +""" +Application initializer for LG-SOTF. + +This module provides the main application lifecycle management including: +- Component initialization and dependency injection +- Continuous alert ingestion and processing +- Health monitoring and metrics collection +- Graceful shutdown with resource cleanup +""" + +import asyncio +import logging +import signal +from datetime import datetime, timedelta +from typing import Optional, Set + +from lg_sotf.audit.logger import AuditLogger +from lg_sotf.audit.metrics import MetricsCollector +from lg_sotf.core.config.manager import ConfigManager +from lg_sotf.core.exceptions import LG_SOTFError +from lg_sotf.core.state.manager import StateManager +from lg_sotf.core.workflow import WorkflowEngine +from lg_sotf.storage.postgres import PostgreSQLStorage +from lg_sotf.storage.redis import RedisStorage + + +BANNER = """ +╔══════════════════════════════════════════════════════════════════════════════════════════╗ +║ ║ +║ ⠀⠀⠀⠀⡀⠀⠀⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡀⠀⠀⠀⠀⠀⠀⡀⠀⠀⠀⠀⠀⠀⠀⠀ +║ ⠀⢸⠉⣹⠋⠉⢉⡟⢩⢋⠋⣽⡻⠭⢽⢉⠯⠭⠭⠭⢽⡍⢹⡍⠙⣯⠉⠉⠉⠉⠉⣿⢫⠉⠉⠉⢉⡟⠉⢿⢹⠉⢉⣉⢿⡝⡉⢩⢿⣻⢍⠉⠉⠩⢹⣟⡏⠉⠹⡉⢻⡍⡇ +║ ⠀⢸⢠⢹⠀⠀⢸⠁⣼⠀⣼⡝⠀⠀⢸⠘⠀⠀⠀⠀⠈⢿⠀⡟⡄⠹⣣⠀⠀⠐⠀⢸⡘⡄⣤⠀⡼⠁⠀⢺⡘⠉⠀⠀⠀⠫⣪⣌⡌⢳⡻⣦⠀⠀⢃⡽⡼⡀⠀⢣⢸⠸⡇ +║ ⠀⢸⡸⢸⠀⠀⣿⠀⣇⢠⡿⠀⠀⠀⠸⡇⠀⠀⠀⠀⠀⠘⢇⠸⠘⡀⠻⣇⠀⠀⠄⠀⡇⢣⢛⠀⡇⠀⠀⣸⠇⠀⠀⠀⠀⠀⠘⠄⢻⡀⠻⣻⣧⠀⠀⠃⢧⡇⠀⢸⢸⡇⡇ +║ ⠀⢸⡇⢸⣠⠀⣿⢠⣿⡾⠁⠀⢀⡀⠤⢇⣀⣐⣀⠀⠤⢀⠈⠢⡡⡈⢦⡙⣷⡀⠀⠀⢿⠈⢻⣡⠁⠀⢀⠏⠀⠀⠀⢀⠀⠄⣀⣐⣀⣙⠢⡌⣻⣷⡀⢹⢸⡅⠀⢸⠸⡇⡇ +║ ⠀⢸⡇⢸⣟⠀⢿⢸⡿⠀⣀⣶⣷⣾⡿⠿⣿⣿⣿⣿⣿⣶⣬⡀⠐⠰⣄⠙⠪⣻⣦⡀⠘⣧⠀⠙⠄⠀⠀⠀⠀⠀⣨⣴⣾⣿⠿⣿⣿⣿⣿⣿⣶⣯⣿⣼⢼⡇⠀⢸⡇⡇⡇ +║ ⠀⢸⢧⠀⣿⡅⢸⣼⡷⣾⣿⡟⠋⣿⠓⢲⣿⣿⣿⡟⠙⣿⠛⢯⡳⡀⠈⠓⠄⡈⠚⠿⣧⣌⢧⠀⠀⠀⠀⠀⣠⣺⠟⢫⡿⠓⢺⣿⣿⣿⠏⠙⣏⠛⣿⣿⣾⡇⢀⡿⢠⠀⡇ +║ ⠀⢸⢸⠀⢹⣷⡀⢿⡁⠀⠻⣇⠀⣇⠀⠘⣿⣿⡿⠁⠐⣉⡀⠀⠁⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠳⠄⠀⠀⠀⠀⠋⠀⠘⡇⠀⠸⣿⣿⠟⠀⢈⣉⢠⡿⠁⣼⠁⣼⠃⣼⠀⡇ +║ ⠀⢸⠸⣀⠈⣯⢳⡘⣇⠀⠀⠈⡂⣜⣆⡀⠀⠀⢀⣀⡴⠇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢽⣆⣀⠀⠀⠀⣀⣜⠕⡊⠀⣸⠇⣼⡟⢠⠏⠀⡇ +║ ⠀⢸⠀⡟⠀⢸⡆⢹⡜⡆⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⠋⣾⡏⡇⡎⡇⠀⡇ +║ ⠀⢸⠀⢃⡆⠀⢿⡄⠑⢽⣄⠀⠀⠀⢀⠂⠠⢁⠈⠄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠠⠂⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡀⠀⠄⡐⢀⠂⠀⠀⣠⣮⡟⢹⣯⣸⣱⠁⠀⡇ +║ ⠀⠈⠉⠉⠋⠉⠉⠋⠉⠉⠉⠋⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠋⡟⠉⠉⡿⠋⠋⠋⠉⠉⠁ +║ +║ LangGraph SOC Triage & Orchestration Framework (LG-SOTF) +║ Version 1.0.0 - Copy Right 2025 ║ +╚══════════════════════════════════════════════════════════════════════════════════════════╝ +""" + +class LG_SOTFApplication: + """Main application class for LG-SOTF.""" + + def __init__(self, config_path: Optional[str] = None, setup_signal_handlers: bool = True): + """Initialize the application. 
+ + Args: + config_path: Path to configuration file + setup_signal_handlers: Whether to setup signal handlers (disable when running under uvicorn) + """ + self.config_path = config_path + self.config_manager = None + self.state_manager = None + self.workflow_engine = None + self.audit_logger = None + self.metrics = None + self.postgres_storage = None + self.redis_storage = None + + # Application state + self.running = False + self.initialized = False + + # Task tracking for graceful shutdown + self._active_tasks: Set[asyncio.Task] = set() + self._shutdown_event = asyncio.Event() + + # Ingestion tracking + self._last_ingestion_poll: Optional[datetime] = None + self._last_health_check: Optional[datetime] = None + self._ingestion_lock = asyncio.Lock() + + # Setup signal handlers (only when not running under uvicorn) + if setup_signal_handlers: + self._setup_signal_handlers() + + def _setup_signal_handlers(self): + """Setup signal handlers for graceful shutdown.""" + def signal_handler(signum, frame): + """Handle shutdown signals.""" + logging.info(f"Received signal {signum}, initiating graceful shutdown...") + self.running = False + # Use call_soon_threadsafe to safely set the event from signal handler + try: + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(self._shutdown_event.set) + except RuntimeError: + # If no loop is running, just set it directly (shouldn't happen in normal flow) + self._shutdown_event.set() + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + async def initialize(self): + """Initialize all application components. + + Raises: + LG_SOTFError: If initialization fails + """ + try: + logging.info("=== Initializing LG-SOTF Application ===") + + # Load configuration + self.config_manager = ConfigManager(self.config_path) + logging.info("✓ Configuration loaded") + + # Validate configuration + self.config_manager.validate() + logging.info("✓ Configuration validated") + + # Initialize audit and metrics + self.audit_logger = AuditLogger() + self.metrics = MetricsCollector(self.config_manager) + logging.info("✓ Audit and metrics initialized") + + # Initialize storage backends + await self._initialize_storage() + logging.info("✓ Storage backends initialized") + + # Initialize state manager + self.state_manager = StateManager(self.postgres_storage) + logging.info("✓ State manager initialized") + + # Initialize workflow engine (handles agent initialization) + # Pass Redis storage and let the workflow engine create the tool orchestrator + self.workflow_engine = WorkflowEngine( + self.config_manager, + self.state_manager, + redis_storage=self.redis_storage, + tool_orchestrator=None # Created internally by workflow engine + ) + await self.workflow_engine.initialize() + logging.info("✓ Workflow engine initialized") + + # Verify agents are registered + from lg_sotf.agents.registry import agent_registry + stats = agent_registry.get_registry_stats() + logging.info( + f"✓ Agent registry: {stats['agent_types_count']} types, " + f"{stats['agent_instances_count']} instances, " + f"{len(stats['initialized_agents'])} initialized" + ) + + # Log application start + self.audit_logger.log_application_start( + config_path=self.config_path, + version="0.1.0" + ) + + self.initialized = True + logging.info(BANNER) + + except Exception as e: + logging.error(f"Failed to initialize application: {e}", exc_info=True) + raise LG_SOTFError(f"Application initialization failed: {e}") + + async def _initialize_storage(self): + """Initialize storage 
backends. + + Raises: + Exception: If storage initialization fails + """ + try: + # Initialize PostgreSQL + db_config = self.config_manager.get_database_config() + connection_string = ( + f"postgresql://{db_config.username}:{db_config.password}@" + f"{db_config.host}:{db_config.port}/{db_config.database}" + ) + + self.postgres_storage = PostgreSQLStorage(connection_string) + await self.postgres_storage.initialize() + logging.info(f" - PostgreSQL connected: {db_config.host}:{db_config.port}") + + # Initialize Redis + redis_config = self.config_manager.get_redis_config() + redis_password = f":{redis_config.password}@" if redis_config.password else "" + redis_connection_string = ( + f"redis://{redis_password}{redis_config.host}:{redis_config.port}/{redis_config.db}" + ) + + self.redis_storage = RedisStorage(redis_connection_string) + await self.redis_storage.initialize() + logging.info(f" - Redis connected: {redis_config.host}:{redis_config.port}") + + except Exception as e: + logging.error(f"Failed to initialize storage: {e}", exc_info=True) + raise + + async def run(self): + """Run the main application loop. + + This method handles: + - Continuous alert ingestion and processing + - Periodic health checks + - Graceful shutdown on signal + """ + try: + self.running = True + logging.info("🚀 LG-SOTF Application Started") + logging.info("Press Ctrl+C to shutdown gracefully\n") + + # Create background tasks + ingestion_task = asyncio.create_task(self._ingestion_loop()) + health_check_task = asyncio.create_task(self._health_check_loop()) + + # Wait for shutdown signal + await self._shutdown_event.wait() + + # Cancel background tasks + logging.info("Stopping background tasks...") + ingestion_task.cancel() + health_check_task.cancel() + + # Wait for tasks to complete + await asyncio.gather(ingestion_task, health_check_task, return_exceptions=True) + + except asyncio.CancelledError: + logging.info("Main loop cancelled") + except Exception as e: + logging.error(f"Unexpected error in main loop: {e}", exc_info=True) + finally: + await self.shutdown() + + async def _ingestion_loop(self): + """Continuous ingestion loop. + + Polls configured sources at regular intervals and processes alerts. + """ + try: + while self.running: + try: + await self._process_alerts() + except asyncio.CancelledError: + raise + except Exception as e: + logging.error(f"Error in ingestion loop: {e}", exc_info=True) + self.metrics.increment_counter("ingestion_loop_errors") + await asyncio.sleep(5) # Back off on error + + # Sleep briefly to prevent tight loop + await asyncio.sleep(1) + + except asyncio.CancelledError: + logging.info("Ingestion loop cancelled") + + async def _health_check_loop(self): + """Periodic health check loop. + + Performs system health checks at regular intervals. + Responsive to shutdown signals. + """ + try: + while self.running: + try: + await self._perform_health_checks() + except asyncio.CancelledError: + raise + except Exception as e: + logging.error(f"Error in health check loop: {e}", exc_info=True) + self.metrics.increment_counter("health_check_errors") + + # Wait before next health check, but wake up on shutdown + try: + await asyncio.wait_for(self._shutdown_event.wait(), timeout=60) + # If we get here, shutdown was triggered + break + except asyncio.TimeoutError: + # Timeout is normal, continue to next health check + pass + + except asyncio.CancelledError: + logging.info("Health check loop cancelled") + + async def _process_alerts(self): + """Process alerts from ingestion sources. 
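+
+        Worked example (illustrative numbers): with polling_interval=60 and
+        max_concurrent_alerts=10, a poll that returns 25 alerts creates 10
+        workflow tasks and leaves the remaining 15 for the next poll cycle.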
+ + This method: + - Respects polling interval configuration + - Enforces max concurrent alert limit + - Tracks tasks for graceful shutdown + """ + # Check if ingestion agent is available + if not self.workflow_engine or "ingestion" not in self.workflow_engine.agents: + return + + ingestion_agent = self.workflow_engine.agents["ingestion"] + + # Get polling configuration + ingestion_config = self.config_manager.get_agent_config("ingestion") + polling_interval = ingestion_config.get("polling_interval", 60) + max_concurrent = ingestion_agent.max_concurrent_alerts + + # Check if it's time to poll + if self._last_ingestion_poll is not None: + time_since_poll = (datetime.utcnow() - self._last_ingestion_poll).total_seconds() + if time_since_poll < polling_interval: + return + + # Use lock to prevent concurrent polling + if self._ingestion_lock.locked(): + return + + async with self._ingestion_lock: + try: + # Check active task count + active_count = len(self._active_tasks) + if active_count >= max_concurrent: + logging.warning( + f"Max concurrent alerts reached ({active_count}/{max_concurrent}), " + "skipping this poll cycle" + ) + self.metrics.increment_counter("ingestion_poll_skipped_max_concurrent") + return + + # Poll for new alerts + logging.debug("Polling ingestion sources...") + new_alerts = await ingestion_agent.poll_sources() + + if not new_alerts: + self._last_ingestion_poll = datetime.utcnow() + return + + logging.info(f"📥 Ingestion: Found {len(new_alerts)} new alerts") + self.metrics.increment_counter("ingestion_alerts_received", len(new_alerts)) + + # Process alerts respecting concurrency limit + processed_count = 0 + for alert in new_alerts: + # Check if we can process more alerts + if len(self._active_tasks) >= max_concurrent: + remaining = len(new_alerts) - processed_count + logging.warning( + f"Max concurrent limit reached, " + f"queueing {remaining} alerts for next cycle" + ) + self.metrics.increment_counter("ingestion_alerts_queued", remaining) + break + + try: + # Create workflow task + task = asyncio.create_task( + self._process_single_workflow(alert["id"], alert) + ) + + # Track task + self._active_tasks.add(task) + task.add_done_callback(self._active_tasks.discard) + + processed_count += 1 + + except Exception as e: + logging.error( + f"Failed to create workflow task for alert {alert.get('id', 'unknown')}: {e}", + exc_info=True + ) + self.metrics.increment_counter("workflow_creation_errors") + + logging.info( + f"✓ Created {processed_count} workflow tasks " + f"({len(self._active_tasks)} active)" + ) + self.metrics.set_gauge("active_workflow_tasks", len(self._active_tasks)) + + # Update last poll time + self._last_ingestion_poll = datetime.utcnow() + self.metrics.record_histogram("ingestion_poll_interval", polling_interval) + + except Exception as e: + logging.error(f"Ingestion polling error: {e}", exc_info=True) + self.metrics.increment_counter("ingestion_poll_errors") + + async def _process_single_workflow(self, alert_id: str, alert_data: dict): + """Process a single alert through the workflow. 
+ + Args: + alert_id: Alert identifier + alert_data: Alert data dictionary (already ingested by polling loop) + """ + start_time = datetime.utcnow() + + try: + logging.debug(f"Processing workflow for alert {alert_id}") + + # Skip ingestion node since alert is already ingested by the polling loop + result = await self.workflow_engine.execute_workflow( + alert_id, + alert_data, + skip_ingestion=True # Alert already normalized by ingestion agent polling + ) + + # Calculate processing time + processing_time = (datetime.utcnow() - start_time).total_seconds() + + logging.info( + f"✓ Alert {alert_id} processed: " + f"status={result.get('triage_status', 'unknown')}, " + f"confidence={result.get('confidence_score', 0)}, " + f"time={processing_time:.2f}s" + ) + + # Record metrics + self.metrics.increment_counter("workflow_success") + self.metrics.record_histogram("workflow_processing_time", processing_time) + + except asyncio.CancelledError: + logging.info(f"Workflow for alert {alert_id} cancelled (shutdown)") + raise + except Exception as e: + processing_time = (datetime.utcnow() - start_time).total_seconds() + logging.error( + f"✗ Failed to process alert {alert_id}: {e}", + exc_info=True + ) + self.metrics.increment_counter("workflow_errors") + self.metrics.record_histogram("workflow_error_time", processing_time) + + async def _perform_health_checks(self): + """Perform periodic health checks on all components.""" + # Check if it's time for health check (every 60 seconds) + if self._last_health_check is not None: + time_since_check = (datetime.utcnow() - self._last_health_check).total_seconds() + if time_since_check < 60: + return + + try: + logging.debug("Performing health checks...") + health_status = await self.health_check() + + if health_status: + logging.debug("✓ All components healthy") + else: + logging.warning("⚠ Some components unhealthy") + + self.metrics.set_gauge("health_check_status", 1 if health_status else 0) + self._last_health_check = datetime.utcnow() + + except Exception as e: + logging.error(f"Health check error: {e}", exc_info=True) + self.metrics.increment_counter("health_check_errors") + + async def shutdown(self): + """Shutdown the application gracefully. 
+ + This method: + - Cancels all active workflow tasks + - Shuts down agents + - Closes storage connections + - Cleans up resources + """ + try: + logging.info("\n=== Shutting Down LG-SOTF Application ===") + + # Cancel active workflow tasks + if self._active_tasks: + task_count = len(self._active_tasks) + logging.info(f"Cancelling {task_count} active workflow tasks...") + + for task in self._active_tasks: + if not task.done(): + task.cancel() + + # Wait for tasks to complete with timeout + try: + await asyncio.wait_for( + asyncio.gather(*self._active_tasks, return_exceptions=True), + timeout=10.0 + ) + logging.info(f"✓ All {task_count} workflow tasks cancelled") + except asyncio.TimeoutError: + logging.warning(f"⚠ Some workflow tasks did not complete within timeout") + + # Log application shutdown + if self.audit_logger: + self.audit_logger.log_application_shutdown() + + # Shutdown agents + await self._shutdown_agents() + + # Close storage connections + await self._shutdown_storage() + + # Shutdown metrics collection + if self.metrics: + try: + self.metrics.shutdown() + logging.info("✓ Metrics collection stopped") + except Exception as e: + logging.warning(f"⚠ Error shutting down metrics: {e}") + + logging.info("=== LG-SOTF Application Shutdown Complete ===\n") + + except Exception as e: + logging.error(f"Error during shutdown: {e}", exc_info=True) + + async def _shutdown_agents(self): + """Shutdown all registered agents.""" + try: + from lg_sotf.agents.registry import agent_registry + + logging.info("Shutting down agents...") + await agent_registry.cleanup_all_agents() + logging.info("✓ All agents stopped") + + except Exception as e: + logging.warning(f"⚠ Error shutting down agents: {e}") + + async def _shutdown_storage(self): + """Shutdown storage connections.""" + storage_tasks = [] + + # Schedule PostgreSQL cleanup + if self.postgres_storage: + try: + storage_tasks.append( + asyncio.create_task(self.postgres_storage.close()) + ) + except Exception as e: + logging.warning(f"⚠ Error scheduling PostgreSQL close: {e}") + + # Schedule Redis cleanup + if self.redis_storage: + try: + storage_tasks.append( + asyncio.create_task(self.redis_storage.close()) + ) + except Exception as e: + logging.warning(f"⚠ Error scheduling Redis close: {e}") + + # Wait for storage cleanup with timeout + if storage_tasks: + try: + await asyncio.wait_for( + asyncio.gather(*storage_tasks, return_exceptions=True), + timeout=5.0 + ) + logging.info("✓ Storage connections closed") + except asyncio.TimeoutError: + logging.warning("⚠ Storage cleanup timed out") + except Exception as e: + logging.warning(f"⚠ Error during storage cleanup: {e}") + + async def health_check(self) -> bool: + """Perform comprehensive health check. 
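+
+        Example (illustrative):
+
+            if not await app.health_check():
+                logging.warning("Degraded - inspect the health_* gauges for details")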
+ + Returns: + bool: True if all components healthy, False otherwise + """ + try: + health_results = { + 'config_manager': False, + 'state_manager': False, + 'workflow_engine': False, + 'postgres_storage': False, + 'redis_storage': False, + 'agents': False + } + + # Check configuration + if self.config_manager: + health_results['config_manager'] = True + + # Check state manager + if self.state_manager: + health_results['state_manager'] = True + + # Check workflow engine + if self.workflow_engine: + health_results['workflow_engine'] = True + + # Check PostgreSQL + if self.postgres_storage: + health_results['postgres_storage'] = await self.postgres_storage.health_check() + + # Check Redis + if self.redis_storage: + health_results['redis_storage'] = await self.redis_storage.health_check() + + # Check agents + try: + from lg_sotf.agents.registry import agent_registry + + # Check if any agent is healthy + if agent_registry.agent_exists("ingestion_instance"): + ingestion_agent = agent_registry.get_agent("ingestion_instance") + if hasattr(ingestion_agent, 'health_check'): + health_results['agents'] = await ingestion_agent.health_check() + else: + health_results['agents'] = ingestion_agent.initialized + + except Exception as e: + logging.debug(f"Agent health check error: {e}") + health_results['agents'] = False + + # Record component health metrics + if self.metrics: + for component, status in health_results.items(): + self.metrics.set_gauge(f"health_{component}", 1 if status else 0) + + # Calculate overall health + overall_health = all(health_results.values()) + + # Log unhealthy components + unhealthy = [comp for comp, status in health_results.items() if not status] + if unhealthy: + logging.debug(f"Unhealthy components: {', '.join(unhealthy)}") + + return overall_health + + except Exception as e: + logging.error(f"Health check failed: {e}", exc_info=True) + return False + + async def process_single_alert(self, alert_id: str, alert_data: dict) -> dict: + """Process a single alert through the workflow. + + This method is used for testing and manual alert processing. + + Args: + alert_id: Alert identifier + alert_data: Alert data dictionary + + Returns: + dict: Workflow result + + Raises: + LG_SOTFError: If workflow engine not initialized + """ + try: + if not self.workflow_engine: + raise LG_SOTFError("Workflow engine not initialized") + + logging.info(f"Processing single alert: {alert_id}") + + result = await self.workflow_engine.execute_workflow(alert_id, alert_data) + + logging.info(f"Alert {alert_id} processed successfully") + return result + + except Exception as e: + logging.error(f"Failed to process alert {alert_id}: {e}", exc_info=True) + raise + + def get_application_status(self) -> dict: + """Get comprehensive application status. 
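+
+        Example of the returned shape (abridged; values are illustrative):
+
+            {
+                "running": True,
+                "initialized": True,
+                "active_workflow_tasks": 3,
+                "components": {"workflow_engine": True, ...},
+                "storage": {"postgres": True, "redis": True},
+            }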
+ + Returns: + dict: Application status information + """ + try: + status = { + 'running': self.running, + 'initialized': self.initialized, + 'active_workflow_tasks': len(self._active_tasks), + 'last_ingestion_poll': self._last_ingestion_poll.isoformat() if self._last_ingestion_poll else None, + 'last_health_check': self._last_health_check.isoformat() if self._last_health_check else None, + 'components': { + 'config_manager': self.config_manager is not None, + 'state_manager': self.state_manager is not None, + 'workflow_engine': self.workflow_engine is not None, + 'audit_logger': self.audit_logger is not None, + 'metrics': self.metrics is not None + }, + 'storage': { + 'postgres': self.postgres_storage is not None, + 'redis': self.redis_storage is not None + } + } + + # Add agent status + try: + from lg_sotf.agents.registry import agent_registry + status['agents'] = agent_registry.get_registry_stats() + except Exception as e: + logging.warning(f"Error getting agent status: {e}") + status['agents'] = {'error': str(e)} + + return status + + except Exception as e: + logging.error(f"Error getting application status: {e}", exc_info=True) + return {'error': str(e)} diff --git a/src/lg_sotf/core/graph.py b/src/lg_sotf/core/graph.py new file mode 100644 index 00000000..dfb9f29c --- /dev/null +++ b/src/lg_sotf/core/graph.py @@ -0,0 +1,374 @@ +""" +LangGraph workflow graph construction and routing logic. + +This module defines the graph structure, state schema, and routing decisions +for the SOC alert processing workflow. +""" + +import operator +from typing import Any, Dict, List, Literal, TypedDict, Annotated + +from pydantic import BaseModel, Field +from langgraph.graph import END, START, StateGraph + +from lg_sotf.core.config.manager import ConfigManager + + +class ExecutionContextData(TypedDict): + """Typed execution context for state.""" + execution_id: str + started_at: str + last_node: str + executed_nodes: List[str] + execution_time: str + + +class RoutingDecision(BaseModel): + """Structured LLM routing decision following LangGraph best practices.""" + next_step: Literal["correlation", "analysis", "response", "human_loop", "close"] = Field( + description="The next processing step for the alert" + ) + confidence: int = Field(ge=0, le=100, description="Confidence in this routing decision") + reasoning: str = Field(description="Brief reasoning for this routing choice") + + +class WorkflowState(TypedDict): + """State schema for LangGraph workflow with proper reducers.""" + # Core identification + alert_id: str + workflow_instance_id: str + execution_context: ExecutionContextData + + # Alert data + raw_alert: Dict[str, Any] + enriched_data: Dict[str, Any] + + # Status and scoring + triage_status: str + confidence_score: int + current_node: str + priority_level: int + + # Indicators - WITH REDUCERS for accumulation + fp_indicators: Annotated[List[str], operator.add] + tp_indicators: Annotated[List[str], operator.add] + + # Correlation data - WITH REDUCER + correlations: Annotated[List[Dict[str, Any]], operator.add] + correlation_score: int + + # Analysis data + analysis_conclusion: str + threat_score: int + recommended_actions: Annotated[List[str], operator.add] + analysis_reasoning: Annotated[List[Dict[str, Any]], operator.add] + tool_results: Dict[str, Dict[str, Any]] + + # Processing tracking - WITH REDUCER + processing_notes: Annotated[List[str], operator.add] + last_updated: str + + # Execution guards + agent_executions: Dict[str, Dict[str, Any]] # Track agent execution state + 
state_version: int # State versioning for conflict detection + + +class WorkflowGraphBuilder: + """Builder for constructing the LangGraph workflow graph.""" + + def __init__(self, config_manager: ConfigManager): + """Initialize graph builder with configuration.""" + self.config = config_manager + self.routing_config = { + 'max_alert_age_hours': config_manager.get('routing.max_alert_age_hours', 72), + 'correlation_grey_zone_min': config_manager.get('routing.correlation_grey_zone_min', 30), + 'correlation_grey_zone_max': config_manager.get('routing.correlation_grey_zone_max', 70), + 'analysis_threshold': config_manager.get('routing.analysis_threshold', 40), + 'human_review_min': config_manager.get('routing.human_review_min', 20), + 'human_review_max': config_manager.get('routing.human_review_max', 60), + 'response_threshold': config_manager.get('routing.response_threshold', 80), + } + + def build_graph(self, node_executors: Dict[str, Any]) -> StateGraph: + """Build the LangGraph workflow graph. + + Args: + node_executors: Dictionary mapping node names to executor functions + + Returns: + Compiled StateGraph ready for execution + """ + workflow = StateGraph(WorkflowState) + + # Add nodes with execution wrappers + for node_name, executor in node_executors.items(): + workflow.add_node(node_name, executor) + + # Set entry point + workflow.add_edge(START, "ingestion") + + # Add conditional edges with routing + workflow.add_conditional_edges( + "ingestion", + self.route_after_ingestion, + {"triage": "triage", "close": "close"}, + ) + + workflow.add_conditional_edges( + "triage", + self.route_after_triage, + { + "correlation": "correlation", + "analysis": "analysis", + "human_loop": "human_loop", + "response": "response", + "close": "close", + }, + ) + + workflow.add_conditional_edges( + "correlation", + self.route_after_correlation, + { + "analysis": "analysis", + "response": "response", + "human_loop": "human_loop", + "close": "close", + }, + ) + + workflow.add_conditional_edges( + "analysis", + self.route_after_analysis, + {"human_loop": "human_loop", "response": "response", "close": "close"}, + ) + + workflow.add_conditional_edges( + "human_loop", + self.route_after_human_loop, + {"analysis": "analysis", "response": "response", "close": "close"}, + ) + + workflow.add_conditional_edges( + "response", + self.route_after_response, + {"learning": "learning", "close": "close"}, + ) + + workflow.add_edge("learning", "close") + workflow.add_edge("close", END) + + # Validate graph structure + self._validate_graph(workflow) + + return workflow + + def _validate_graph(self, workflow: StateGraph): + """Validate graph structure before compilation. + + Args: + workflow: StateGraph to validate + + Raises: + ValueError: If graph structure is invalid + """ + # LangGraph will do most validation on compile(), + # but we can add custom checks here + if not hasattr(workflow, 'nodes') or len(workflow.nodes) == 0: + raise ValueError("Graph has no nodes defined") + + # =============================== + # ROUTING METHODS + # =============================== + + async def route_after_triage( + self, + state: WorkflowState + ) -> Literal["correlation", "analysis", "response", "human_loop", "close"]: + """Intelligent async routing after triage. 
+ + Following LangGraph best practices: + - ALWAYS route through correlation for threat intelligence building + - Fast-track obvious FPs to close + - All other alerts go through correlation + """ + confidence = state["confidence_score"] + fp_count = len(state["fp_indicators"]) + tp_count = len(state["tp_indicators"]) + + # ONLY obvious false positives skip correlation and go directly to close + if confidence <= 10 and fp_count >= 2: + return "close" + + if fp_count > tp_count and fp_count >= 3 and confidence <= 20: + return "close" + + # ALL other alerts MUST go through correlation first to build threat intelligence + return "correlation" + + def route_after_correlation(self, state: WorkflowState) -> str: + """Routing after correlation.""" + correlations = state.get("correlations", []) + correlation_score = state.get("correlation_score", 0) + confidence = state["confidence_score"] + + # Strong correlations → direct response + if correlation_score > 85 and len(correlations) >= 5: + return "response" + + # Moderate correlations → analysis + if correlation_score > 60 or len(correlations) >= 3: + return "analysis" + + # Weak correlations → human review + if correlation_score > 20 and confidence > 50: + return "human_loop" + + # No meaningful correlations → close + return "close" + + def route_after_analysis(self, state: WorkflowState) -> str: + """Routing after analysis.""" + threat_score = state.get("threat_score", 0) + confidence = state["confidence_score"] + conclusion = state.get("analysis_conclusion", "").lower() + + # High threat → response + if threat_score >= 80 or (threat_score >= 60 and confidence >= 80): + return "response" + + # Uncertain analysis → human review + if "uncertain" in conclusion or (30 <= confidence <= 70): + return "human_loop" + + # Low threat → close + return "close" + + def route_after_ingestion(self, state: WorkflowState) -> str: + """Routing logic after ingestion.""" + return "triage" if state["raw_alert"] else "close" + + def route_after_human_loop(self, state: WorkflowState) -> str: + """Routing logic after human loop.""" + confidence = state["confidence_score"] + return "response" if confidence >= 75 else "close" + + def route_after_response(self, state: WorkflowState) -> str: + """Routing logic after response.""" + return "learning" if self._should_learn(state) else "close" + + # =============================== + # ROUTING HELPER METHODS + # =============================== + + def _needs_correlation(self, state: WorkflowState) -> bool: + """Determine if alert needs correlation based on indicators.""" + enriched_data = state.get("enriched_data", {}) + raw_alert = state.get("raw_alert", {}) + + # Check for network indicators + has_network_indicators = any([ + enriched_data.get("source_ip"), + enriched_data.get("destination_ip"), + enriched_data.get("domain"), + enriched_data.get("url"), + raw_alert.get("source_ip"), + raw_alert.get("destination_ip"), + raw_alert.get("domain") + ]) + + # Check for user indicators + has_user_indicators = any([ + enriched_data.get("user"), + enriched_data.get("username"), + raw_alert.get("user"), + raw_alert.get("username") + ]) + + # Check for file/hash indicators + has_file_indicators = any([ + enriched_data.get("file_hash"), + enriched_data.get("file_name"), + raw_alert.get("file_hash"), + raw_alert.get("sha256"), + raw_alert.get("md5") + ]) + + return has_network_indicators or has_user_indicators or has_file_indicators + + def _needs_analysis(self, state: WorkflowState) -> bool: + """Determine if alert needs deep analysis.""" 
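+        # Any single heuristic below is sufficient to request deep analysis.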
+        confidence = state.get("confidence_score", 0)
+        fp_count = len(state.get("fp_indicators", []))
+        tp_count = len(state.get("tp_indicators", []))
+        category = state.get("enriched_data", {}).get("category", "").lower()
+
+        # Low confidence with mixed signals
+        if confidence < 40 and fp_count > 0 and tp_count > 0:
+            return True
+
+        # Complex attack categories that need investigation
+        complex_categories = [
+            "lateral_movement",
+            "privilege_escalation",
+            "persistence",
+            "command_and_control",
+            "exfiltration",
+            "malware"
+        ]
+        if any(cat in category for cat in complex_categories):
+            return True
+
+        # Alerts with tool results need deeper analysis
+        tool_results = state.get("tool_results", {})
+        if tool_results:
+            return True
+
+        return False
+
+    def _should_learn(self, state: WorkflowState) -> bool:
+        """Check if learning is beneficial."""
+        return (state.get("human_feedback") or
+                any("unusual" in note.lower() for note in state["processing_notes"]))
+
+    # ===============================
+    # FALLBACK ROUTING (when LLM unavailable)
+    # ===============================
+
+    def fallback_routing(
+        self,
+        state: WorkflowState,
+        after_node: str
+    ) -> Literal["correlation", "analysis", "response", "human_loop", "close"]:
+        """Fallback routing logic when LLM is unavailable or fails."""
+        confidence = state["confidence_score"]
+        fp_count = len(state["fp_indicators"])
+        tp_count = len(state["tp_indicators"])
+
+        if after_node == "triage":
+            # Fallback rules following SOC best practices
+            # Close obvious false positives
+            if confidence <= 10 and fp_count >= 2:
+                return "close"
+
+            # Fast-track high-confidence threats to response
+            if confidence >= 85 and tp_count >= 3:
+                return "response"
+
+            # Prefer correlation first for grey-zone cases (gather context)
+            if self._needs_correlation(state) or 20 <= confidence <= 80:
+                return "correlation"
+
+            # Remaining confidence bands (< 20, or 81-84 without correlation
+            # indicators) → analyze to confirm
+            if confidence < 85:
+                return "analysis"
+
+            # Default fallback: correlation (safest, gathers context)
+            return "correlation"
+
+        # For other nodes, conservative default
+        return "close"
diff --git a/src/lg_sotf/core/workflow.py b/src/lg_sotf/core/workflow.py
index 67799214..b2fc1b11 100644
--- a/src/lg_sotf/core/workflow.py
+++ b/src/lg_sotf/core/workflow.py
@@ -1,6 +1,8 @@
 """
- WorkflowEngine with atomic state management and proper agent coordination.
-Fixes the state corruption and duplicate execution issues.
+WorkflowEngine for orchestrating multi-agent SOC alert processing.
+
+This module handles workflow execution, agent coordination, and state management.
+The graph structure and routing logic are defined in graph.py.
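+
+List-valued state fields (processing_notes, fp_indicators, correlations, ...)
+are declared with operator.add reducers in graph.py, so node updates append to
+them rather than overwriting.
+
+Typical usage (a minimal sketch; the optional storage and orchestrator
+arguments are omitted here):
+
+    engine = WorkflowEngine(config_manager, state_manager)
+    await engine.initialize()
+    result = await engine.execute_workflow(alert_id, raw_alert_dict)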
""" import asyncio @@ -8,13 +10,9 @@ from dataclasses import dataclass from datetime import datetime, timedelta from threading import RLock -from typing import Any, Dict, List, TypedDict, Annotated, Literal -import operator +from typing import Any, Dict, List, Literal import json -from pydantic import BaseModel, Field -from langgraph.graph import END, START, StateGraph - from lg_sotf.agents.analysis.base import AnalysisAgent from lg_sotf.agents.correlation.base import CorrelationAgent from lg_sotf.agents.human_loop.base import HumanLoopAgent @@ -28,6 +26,13 @@ from lg_sotf.core.state.model import SOCState, TriageStatus from lg_sotf.utils.llm import get_llm_client +# Import graph components +from lg_sotf.core.graph import ( + WorkflowState, + WorkflowGraphBuilder, + RoutingDecision +) + @dataclass class ExecutionContext: @@ -38,67 +43,8 @@ class ExecutionContext: locks: Dict[str, asyncio.Lock] # Per-node locks -class ExecutionContextData(TypedDict): - """Typed execution context for state.""" - execution_id: str - started_at: str - last_node: str - executed_nodes: List[str] - execution_time: str - - -class RoutingDecision(BaseModel): - """Structured LLM routing decision following LangGraph best practices.""" - next_step: Literal["correlation", "analysis", "response", "human_loop", "close"] = Field( - description="The next processing step for the alert" - ) - confidence: int = Field(ge=0, le=100, description="Confidence in this routing decision") - reasoning: str = Field(description="Brief reasoning for this routing choice") - - -class WorkflowState(TypedDict): - """State schema for LangGraph workflow with proper reducers.""" - # Core identification - alert_id: str - workflow_instance_id: str - execution_context: ExecutionContextData - - # Alert data - raw_alert: Dict[str, Any] - enriched_data: Dict[str, Any] - - # Status and scoring - triage_status: str - confidence_score: int - current_node: str - priority_level: int - - # Indicators - WITH REDUCERS for accumulation - fp_indicators: Annotated[List[str], operator.add] - tp_indicators: Annotated[List[str], operator.add] - - # Correlation data - WITH REDUCER - correlations: Annotated[List[Dict[str, Any]], operator.add] - correlation_score: int - - # Analysis data - analysis_conclusion: str - threat_score: int - recommended_actions: Annotated[List[str], operator.add] - analysis_reasoning: Annotated[List[Dict[str, Any]], operator.add] - tool_results: Dict[str, Dict[str, Any]] - - # Processing tracking - WITH REDUCER - processing_notes: Annotated[List[str], operator.add] - last_updated: str - - # Execution guards - agent_executions: Dict[str, Dict[str, Any]] # Track agent execution state - state_version: int # State versioning for conflict detection - - class WorkflowEngine: - """ workflow engine with atomic state management.""" + """Workflow engine with atomic state management.""" def __init__(self, config_manager: ConfigManager, state_manager: StateManager, redis_storage=None, tool_orchestrator=None): self.config = config_manager @@ -124,29 +70,38 @@ def __init__(self, config_manager: ConfigManager, state_manager: StateManager, r self._execution_contexts = {} # Track active executions self._agent_locks = {} # Per-agent execution locks - # Routing configuration - self.routing_config = { - 'max_alert_age_hours': config_manager.get('routing.max_alert_age_hours', 72), - 'correlation_grey_zone_min': config_manager.get('routing.correlation_grey_zone_min', 30), - 'correlation_grey_zone_max': config_manager.get('routing.correlation_grey_zone_max', 70), 
- 'analysis_threshold': config_manager.get('routing.analysis_threshold', 40), - 'human_review_min': config_manager.get('routing.human_review_min', 20), - 'human_review_max': config_manager.get('routing.human_review_max', 60), - 'response_threshold': config_manager.get('routing.response_threshold', 80), - } - - self.graph = self._build_workflow_graph() + # Graph builder + self.graph_builder = WorkflowGraphBuilder(config_manager) self.compiled_graph = None + # Recursion limit for safety (prevents infinite loops) + self.recursion_limit = config_manager.get('workflow.recursion_limit', 50) + async def initialize(self): - """Initialize the workflow engine.""" + """Initialize the workflow engine.""" try: await self._setup_agents() - self.compiled_graph = self.graph.compile() - self.logger.info(" WorkflowEngine initialized") + + # Build node executors map for graph construction + node_executors = { + "ingestion": self._execute_ingestion, + "triage": self._execute_triage, + "correlation": self._execute_correlation, + "analysis": self._execute_analysis, + "human_loop": self._execute_human_loop, + "response": self._execute_response, + "learning": self._execute_learning, + "close": self._execute_close, + } + + # Build and compile graph using WorkflowGraphBuilder + graph = self.graph_builder.build_graph(node_executors) + self.compiled_graph = graph.compile() + + self.logger.info("WorkflowEngine initialized successfully") except Exception as e: - self.logger.error(f"Failed to initialize WorkflowEngine: {e}") + self.logger.error(f"Failed to initialize WorkflowEngine: {e}") raise WorkflowError(f"Initialization failed: {e}") async def _setup_agents(self): @@ -257,75 +212,7 @@ def _create_execution_context(self, alert_id: str) -> ExecutionContext: self._execution_contexts[alert_id] = context return context - def _build_workflow_graph(self) -> StateGraph: - """Build the LangGraph workflow.""" - workflow = StateGraph(WorkflowState) - - # Add nodes with synchronization wrappers - workflow.add_node("ingestion", self._execute_ingestion) - workflow.add_node("triage", self._execute_triage) - workflow.add_node("correlation", self._execute_correlation) - workflow.add_node("analysis", self._execute_analysis) - workflow.add_node("human_loop", self._execute_human_loop) - workflow.add_node("response", self._execute_response) - workflow.add_node("learning", self._execute_learning) - workflow.add_node("close", self._execute_close) - - # Set entry point - workflow.add_edge(START, "ingestion") - - # Add conditional edges with routing - workflow.add_conditional_edges( - "ingestion", - self._route_after_ingestion, - {"triage": "triage", "close": "close"}, - ) - - workflow.add_conditional_edges( - "triage", - self._route_after_triage, - { - "correlation": "correlation", - "analysis": "analysis", - "human_loop": "human_loop", - "response": "response", - "close": "close", - }, - ) - - workflow.add_conditional_edges( - "correlation", - self._route_after_correlation, - { - "analysis": "analysis", - "response": "response", - "human_loop": "human_loop", - "close": "close", - }, - ) - - workflow.add_conditional_edges( - "analysis", - self._route_after_analysis, - {"human_loop": "human_loop", "response": "response", "close": "close"}, - ) - - workflow.add_conditional_edges( - "human_loop", - self._route_after_human_loop, - {"analysis": "analysis", "response": "response", "close": "close"}, - ) - - workflow.add_conditional_edges( - "response", - self._route_after_response, - {"learning": "learning", "close": "close"}, - ) - - 
workflow.add_edge("learning", "close") - workflow.add_edge("close", END) - - return workflow + # Graph building is now handled by WorkflowGraphBuilder in graph.py # =============================== # EXECUTION WRAPPERS @@ -371,15 +258,20 @@ async def _execute_with(self, node_name: str, executor_func, state: WorkflowStat alert_id = state["alert_id"] execution_context = self._execution_contexts.get(alert_id) + # ✅ IMPROVEMENT: Return informative updates instead of empty dict if not execution_context: self.logger.error(f"No execution context for alert {alert_id}") - return {} + return { + "processing_notes": [f"⚠️ Missing execution context for {node_name}"] + } # Check if this node already executed with self._state_lock: if execution_context.node_executions.get(node_name, False): self.logger.warning(f"Node {node_name} already executed for {alert_id}, skipping") - return {} + return { + "processing_notes": [f"⏭️ Skipped {node_name} (already executed)"] + } # Acquire node-specific lock async with execution_context.locks[node_name]: @@ -387,7 +279,9 @@ async def _execute_with(self, node_name: str, executor_func, state: WorkflowStat with self._state_lock: if execution_context.node_executions.get(node_name, False): self.logger.warning(f"Node {node_name} executed during lock wait for {alert_id}") - return {} + return { + "processing_notes": [f"⏭️ Skipped {node_name} (executed during lock wait)"] + } # Mark as executing execution_context.node_executions[node_name] = True @@ -987,9 +881,13 @@ async def execute_workflow(self, alert_id: str, initial_state: Dict[str, Any], s "state_version": 1 } - # Execute through LangGraph - self.logger.info(f"🚀 Starting workflow for {alert_id}") - result_state = await self.compiled_graph.ainvoke(workflow_state) + # Execute through LangGraph with recursion limit for safety + # ✅ IMPROVEMENT: Add recursion_limit to prevent infinite loops + config = { + "recursion_limit": self.recursion_limit + } + self.logger.info(f"🚀 Starting workflow for {alert_id} (recursion_limit={self.recursion_limit})") + result_state = await self.compiled_graph.ainvoke(workflow_state, config) # Cleanup execution context if alert_id in self._execution_contexts: diff --git a/src/lg_sotf/main.py b/src/lg_sotf/main.py index 0e451885..dc355923 100644 --- a/src/lg_sotf/main.py +++ b/src/lg_sotf/main.py @@ -1,833 +1,46 @@ -""" -Main application entry point for LG-SOTF - Production Version. +"""Lightweight main entry point for LG-SOTF.""" -This module provides the main application lifecycle management including: -- Component initialization and dependency injection -- Continuous alert ingestion and processing -- Health monitoring and metrics collection -- Graceful shutdown with resource cleanup -""" - -import asyncio +import argparse import logging -import signal -import sys -from datetime import datetime, timedelta -from pathlib import Path -from typing import Optional, Set - -from lg_sotf.audit.logger import AuditLogger -from lg_sotf.audit.metrics import MetricsCollector -from lg_sotf.core.config.manager import ConfigManager -from lg_sotf.core.exceptions import LG_SOTFError -from lg_sotf.core.state.manager import StateManager -from lg_sotf.core.workflow import WorkflowEngine -from lg_sotf.storage.postgres import PostgreSQLStorage -from lg_sotf.storage.redis import RedisStorage - - -class LG_SOTFApplication: - """Main application class for LG-SOTF.""" - - def __init__(self, config_path: Optional[str] = None, setup_signal_handlers: bool = True): - """Initialize the application. 
- - Args: - config_path: Path to configuration file - setup_signal_handlers: Whether to setup signal handlers (disable when running under uvicorn) - """ - self.config_path = config_path - self.config_manager = None - self.state_manager = None - self.workflow_engine = None - self.audit_logger = None - self.metrics = None - self.postgres_storage = None - self.redis_storage = None - - # Application state - self.running = False - self.initialized = False - - # Task tracking for graceful shutdown - self._active_tasks: Set[asyncio.Task] = set() - self._shutdown_event = asyncio.Event() - - # Ingestion tracking - self._last_ingestion_poll: Optional[datetime] = None - self._last_health_check: Optional[datetime] = None - self._ingestion_lock = asyncio.Lock() - - # Setup signal handlers (only when not running under uvicorn) - if setup_signal_handlers: - self._setup_signal_handlers() - - def _setup_signal_handlers(self): - """Setup signal handlers for graceful shutdown.""" - def signal_handler(signum, frame): - """Handle shutdown signals.""" - logging.info(f"Received signal {signum}, initiating graceful shutdown...") - self.running = False - # Use call_soon_threadsafe to safely set the event from signal handler - try: - loop = asyncio.get_running_loop() - loop.call_soon_threadsafe(self._shutdown_event.set) - except RuntimeError: - # If no loop is running, just set it directly (shouldn't happen in normal flow) - self._shutdown_event.set() - - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - async def initialize(self): - """Initialize all application components. - - Raises: - LG_SOTFError: If initialization fails - """ - try: - logging.info("=== Initializing LG-SOTF Application ===") - - # Load configuration - self.config_manager = ConfigManager(self.config_path) - logging.info("✓ Configuration loaded") - - # Validate configuration - self.config_manager.validate() - logging.info("✓ Configuration validated") - - # Initialize audit and metrics - self.audit_logger = AuditLogger() - self.metrics = MetricsCollector(self.config_manager) - logging.info("✓ Audit and metrics initialized") - - # Initialize storage backends - await self._initialize_storage() - logging.info("✓ Storage backends initialized") - - # Initialize state manager - self.state_manager = StateManager(self.postgres_storage) - logging.info("✓ State manager initialized") - - # Initialize workflow engine (handles agent initialization) - # Pass Redis storage and let the workflow engine create the tool orchestrator - self.workflow_engine = WorkflowEngine( - self.config_manager, - self.state_manager, - redis_storage=self.redis_storage, - tool_orchestrator=None # Created internally by workflow engine - ) - await self.workflow_engine.initialize() - logging.info("✓ Workflow engine initialized") - - # Verify agents are registered - from lg_sotf.agents.registry import agent_registry - stats = agent_registry.get_registry_stats() - logging.info( - f"✓ Agent registry: {stats['agent_types_count']} types, " - f"{stats['agent_instances_count']} instances, " - f"{len(stats['initialized_agents'])} initialized" - ) - - # Log application start - self.audit_logger.log_application_start( - config_path=self.config_path, - version="0.1.0" - ) - - self.initialized = True - logging.info("=== LG-SOTF Application Initialized Successfully ===\n") - - except Exception as e: - logging.error(f"Failed to initialize application: {e}", exc_info=True) - raise LG_SOTFError(f"Application initialization failed: {e}") - - async def 
_initialize_storage(self): - """Initialize storage backends. - - Raises: - Exception: If storage initialization fails - """ - try: - # Initialize PostgreSQL - db_config = self.config_manager.get_database_config() - connection_string = ( - f"postgresql://{db_config.username}:{db_config.password}@" - f"{db_config.host}:{db_config.port}/{db_config.database}" - ) - - self.postgres_storage = PostgreSQLStorage(connection_string) - await self.postgres_storage.initialize() - logging.info(f" - PostgreSQL connected: {db_config.host}:{db_config.port}") - - # Initialize Redis - redis_config = self.config_manager.get_redis_config() - redis_password = f":{redis_config.password}@" if redis_config.password else "" - redis_connection_string = ( - f"redis://{redis_password}{redis_config.host}:{redis_config.port}/{redis_config.db}" - ) - - self.redis_storage = RedisStorage(redis_connection_string) - await self.redis_storage.initialize() - logging.info(f" - Redis connected: {redis_config.host}:{redis_config.port}") - - except Exception as e: - logging.error(f"Failed to initialize storage: {e}", exc_info=True) - raise - - async def run(self): - """Run the main application loop. - - This method handles: - - Continuous alert ingestion and processing - - Periodic health checks - - Graceful shutdown on signal - """ - try: - self.running = True - logging.info("🚀 LG-SOTF Application Started") - logging.info("Press Ctrl+C to shutdown gracefully\n") - - # Create background tasks - ingestion_task = asyncio.create_task(self._ingestion_loop()) - health_check_task = asyncio.create_task(self._health_check_loop()) - - # Wait for shutdown signal - await self._shutdown_event.wait() - - # Cancel background tasks - logging.info("Stopping background tasks...") - ingestion_task.cancel() - health_check_task.cancel() - - # Wait for tasks to complete - await asyncio.gather(ingestion_task, health_check_task, return_exceptions=True) - - except asyncio.CancelledError: - logging.info("Main loop cancelled") - except Exception as e: - logging.error(f"Unexpected error in main loop: {e}", exc_info=True) - finally: - await self.shutdown() - - async def _ingestion_loop(self): - """Continuous ingestion loop. - - Polls configured sources at regular intervals and processes alerts. - """ - try: - while self.running: - try: - await self._process_alerts() - except asyncio.CancelledError: - raise - except Exception as e: - logging.error(f"Error in ingestion loop: {e}", exc_info=True) - self.metrics.increment_counter("ingestion_loop_errors") - await asyncio.sleep(5) # Back off on error - - # Sleep briefly to prevent tight loop - await asyncio.sleep(1) - - except asyncio.CancelledError: - logging.info("Ingestion loop cancelled") - - async def _health_check_loop(self): - """Periodic health check loop. - - Performs system health checks at regular intervals. - Responsive to shutdown signals. 
- """ - try: - while self.running: - try: - await self._perform_health_checks() - except asyncio.CancelledError: - raise - except Exception as e: - logging.error(f"Error in health check loop: {e}", exc_info=True) - self.metrics.increment_counter("health_check_errors") - - # Wait before next health check, but wake up on shutdown - try: - await asyncio.wait_for(self._shutdown_event.wait(), timeout=60) - # If we get here, shutdown was triggered - break - except asyncio.TimeoutError: - # Timeout is normal, continue to next health check - pass - - except asyncio.CancelledError: - logging.info("Health check loop cancelled") - - async def _process_alerts(self): - """Process alerts from ingestion sources. - - This method: - - Respects polling interval configuration - - Enforces max concurrent alert limit - - Tracks tasks for graceful shutdown - """ - # Check if ingestion agent is available - if not self.workflow_engine or "ingestion" not in self.workflow_engine.agents: - return - - ingestion_agent = self.workflow_engine.agents["ingestion"] - - # Get polling configuration - ingestion_config = self.config_manager.get_agent_config("ingestion") - polling_interval = ingestion_config.get("polling_interval", 60) - max_concurrent = ingestion_agent.max_concurrent_alerts - - # Check if it's time to poll - if self._last_ingestion_poll is not None: - time_since_poll = (datetime.utcnow() - self._last_ingestion_poll).total_seconds() - if time_since_poll < polling_interval: - return - - # Use lock to prevent concurrent polling - if self._ingestion_lock.locked(): - return - - async with self._ingestion_lock: - try: - # Check active task count - active_count = len(self._active_tasks) - if active_count >= max_concurrent: - logging.warning( - f"Max concurrent alerts reached ({active_count}/{max_concurrent}), " - "skipping this poll cycle" - ) - self.metrics.increment_counter("ingestion_poll_skipped_max_concurrent") - return - - # Poll for new alerts - logging.debug("Polling ingestion sources...") - new_alerts = await ingestion_agent.poll_sources() - - if not new_alerts: - self._last_ingestion_poll = datetime.utcnow() - return - - logging.info(f"📥 Ingestion: Found {len(new_alerts)} new alerts") - self.metrics.increment_counter("ingestion_alerts_received", len(new_alerts)) - - # Process alerts respecting concurrency limit - processed_count = 0 - for alert in new_alerts: - # Check if we can process more alerts - if len(self._active_tasks) >= max_concurrent: - remaining = len(new_alerts) - processed_count - logging.warning( - f"Max concurrent limit reached, " - f"queueing {remaining} alerts for next cycle" - ) - self.metrics.increment_counter("ingestion_alerts_queued", remaining) - break - - try: - # Create workflow task - task = asyncio.create_task( - self._process_single_workflow(alert["id"], alert) - ) - - # Track task - self._active_tasks.add(task) - task.add_done_callback(self._active_tasks.discard) - - processed_count += 1 - - except Exception as e: - logging.error( - f"Failed to create workflow task for alert {alert.get('id', 'unknown')}: {e}", - exc_info=True - ) - self.metrics.increment_counter("workflow_creation_errors") - - logging.info( - f"✓ Created {processed_count} workflow tasks " - f"({len(self._active_tasks)} active)" - ) - self.metrics.set_gauge("active_workflow_tasks", len(self._active_tasks)) - - # Update last poll time - self._last_ingestion_poll = datetime.utcnow() - self.metrics.record_histogram("ingestion_poll_interval", polling_interval) - - except Exception as e: - logging.error(f"Ingestion 
polling error: {e}", exc_info=True) - self.metrics.increment_counter("ingestion_poll_errors") - - async def _process_single_workflow(self, alert_id: str, alert_data: dict): - """Process a single alert through the workflow. - Args: - alert_id: Alert identifier - alert_data: Alert data dictionary (already ingested by polling loop) - """ - start_time = datetime.utcnow() +import uvicorn - try: - logging.debug(f"Processing workflow for alert {alert_id}") - # Skip ingestion node since alert is already ingested by the polling loop - result = await self.workflow_engine.execute_workflow( - alert_id, - alert_data, - skip_ingestion=True # Alert already normalized by ingestion agent polling - ) - - # Calculate processing time - processing_time = (datetime.utcnow() - start_time).total_seconds() - - logging.info( - f"✓ Alert {alert_id} processed: " - f"status={result.get('triage_status', 'unknown')}, " - f"confidence={result.get('confidence_score', 0)}, " - f"time={processing_time:.2f}s" - ) - - # Record metrics - self.metrics.increment_counter("workflow_success") - self.metrics.record_histogram("workflow_processing_time", processing_time) - - except asyncio.CancelledError: - logging.info(f"Workflow for alert {alert_id} cancelled (shutdown)") - raise - except Exception as e: - processing_time = (datetime.utcnow() - start_time).total_seconds() - logging.error( - f"✗ Failed to process alert {alert_id}: {e}", - exc_info=True - ) - self.metrics.increment_counter("workflow_errors") - self.metrics.record_histogram("workflow_error_time", processing_time) - - async def _perform_health_checks(self): - """Perform periodic health checks on all components.""" - # Check if it's time for health check (every 60 seconds) - if self._last_health_check is not None: - time_since_check = (datetime.utcnow() - self._last_health_check).total_seconds() - if time_since_check < 60: - return - - try: - logging.debug("Performing health checks...") - health_status = await self.health_check() - - if health_status: - logging.debug("✓ All components healthy") - else: - logging.warning("⚠ Some components unhealthy") - - self.metrics.set_gauge("health_check_status", 1 if health_status else 0) - self._last_health_check = datetime.utcnow() - - except Exception as e: - logging.error(f"Health check error: {e}", exc_info=True) - self.metrics.increment_counter("health_check_errors") - - async def shutdown(self): - """Shutdown the application gracefully. 
- - This method: - - Cancels all active workflow tasks - - Shuts down agents - - Closes storage connections - - Cleans up resources - """ - try: - logging.info("\n=== Shutting Down LG-SOTF Application ===") - - # Cancel active workflow tasks - if self._active_tasks: - task_count = len(self._active_tasks) - logging.info(f"Cancelling {task_count} active workflow tasks...") - - for task in self._active_tasks: - if not task.done(): - task.cancel() - - # Wait for tasks to complete with timeout - try: - await asyncio.wait_for( - asyncio.gather(*self._active_tasks, return_exceptions=True), - timeout=10.0 - ) - logging.info(f"✓ All {task_count} workflow tasks cancelled") - except asyncio.TimeoutError: - logging.warning(f"⚠ Some workflow tasks did not complete within timeout") - - # Log application shutdown - if self.audit_logger: - self.audit_logger.log_application_shutdown() - - # Shutdown agents - await self._shutdown_agents() - - # Close storage connections - await self._shutdown_storage() - - # Shutdown metrics collection - if self.metrics: - try: - self.metrics.shutdown() - logging.info("✓ Metrics collection stopped") - except Exception as e: - logging.warning(f"⚠ Error shutting down metrics: {e}") - - logging.info("=== LG-SOTF Application Shutdown Complete ===\n") - - except Exception as e: - logging.error(f"Error during shutdown: {e}", exc_info=True) - - async def _shutdown_agents(self): - """Shutdown all registered agents.""" - try: - from lg_sotf.agents.registry import agent_registry - - logging.info("Shutting down agents...") - await agent_registry.cleanup_all_agents() - logging.info("✓ All agents stopped") - - except Exception as e: - logging.warning(f"⚠ Error shutting down agents: {e}") - - async def _shutdown_storage(self): - """Shutdown storage connections.""" - storage_tasks = [] - - # Schedule PostgreSQL cleanup - if self.postgres_storage: - try: - storage_tasks.append( - asyncio.create_task(self.postgres_storage.close()) - ) - except Exception as e: - logging.warning(f"⚠ Error scheduling PostgreSQL close: {e}") - - # Schedule Redis cleanup - if self.redis_storage: - try: - storage_tasks.append( - asyncio.create_task(self.redis_storage.close()) - ) - except Exception as e: - logging.warning(f"⚠ Error scheduling Redis close: {e}") - - # Wait for storage cleanup with timeout - if storage_tasks: - try: - await asyncio.wait_for( - asyncio.gather(*storage_tasks, return_exceptions=True), - timeout=5.0 - ) - logging.info("✓ Storage connections closed") - except asyncio.TimeoutError: - logging.warning("⚠ Storage cleanup timed out") - except Exception as e: - logging.warning(f"⚠ Error during storage cleanup: {e}") - - async def health_check(self) -> bool: - """Perform comprehensive health check. 
- - Returns: - bool: True if all components healthy, False otherwise - """ - try: - health_results = { - 'config_manager': False, - 'state_manager': False, - 'workflow_engine': False, - 'postgres_storage': False, - 'redis_storage': False, - 'agents': False - } - - # Check configuration - if self.config_manager: - health_results['config_manager'] = True - - # Check state manager - if self.state_manager: - health_results['state_manager'] = True - - # Check workflow engine - if self.workflow_engine: - health_results['workflow_engine'] = True - - # Check PostgreSQL - if self.postgres_storage: - health_results['postgres_storage'] = await self.postgres_storage.health_check() - - # Check Redis - if self.redis_storage: - health_results['redis_storage'] = await self.redis_storage.health_check() - - # Check agents - try: - from lg_sotf.agents.registry import agent_registry - - # Check if any agent is healthy - if agent_registry.agent_exists("ingestion_instance"): - ingestion_agent = agent_registry.get_agent("ingestion_instance") - if hasattr(ingestion_agent, 'health_check'): - health_results['agents'] = await ingestion_agent.health_check() - else: - health_results['agents'] = ingestion_agent.initialized - - except Exception as e: - logging.debug(f"Agent health check error: {e}") - health_results['agents'] = False - - # Record component health metrics - if self.metrics: - for component, status in health_results.items(): - self.metrics.set_gauge(f"health_{component}", 1 if status else 0) - - # Calculate overall health - overall_health = all(health_results.values()) - - # Log unhealthy components - unhealthy = [comp for comp, status in health_results.items() if not status] - if unhealthy: - logging.debug(f"Unhealthy components: {', '.join(unhealthy)}") - - return overall_health - - except Exception as e: - logging.error(f"Health check failed: {e}", exc_info=True) - return False - - async def process_single_alert(self, alert_id: str, alert_data: dict) -> dict: - """Process a single alert through the workflow. - - This method is used for testing and manual alert processing. - - Args: - alert_id: Alert identifier - alert_data: Alert data dictionary - - Returns: - dict: Workflow result - - Raises: - LG_SOTFError: If workflow engine not initialized - """ - try: - if not self.workflow_engine: - raise LG_SOTFError("Workflow engine not initialized") - - logging.info(f"Processing single alert: {alert_id}") - - result = await self.workflow_engine.execute_workflow(alert_id, alert_data) - - logging.info(f"Alert {alert_id} processed successfully") - return result - - except Exception as e: - logging.error(f"Failed to process alert {alert_id}: {e}", exc_info=True) - raise - - def get_application_status(self) -> dict: - """Get comprehensive application status. 
- - Returns: - dict: Application status information - """ - try: - status = { - 'running': self.running, - 'initialized': self.initialized, - 'active_workflow_tasks': len(self._active_tasks), - 'last_ingestion_poll': self._last_ingestion_poll.isoformat() if self._last_ingestion_poll else None, - 'last_health_check': self._last_health_check.isoformat() if self._last_health_check else None, - 'components': { - 'config_manager': self.config_manager is not None, - 'state_manager': self.state_manager is not None, - 'workflow_engine': self.workflow_engine is not None, - 'audit_logger': self.audit_logger is not None, - 'metrics': self.metrics is not None - }, - 'storage': { - 'postgres': self.postgres_storage is not None, - 'redis': self.redis_storage is not None - } - } - - # Add agent status - try: - from lg_sotf.agents.registry import agent_registry - status['agents'] = agent_registry.get_registry_stats() - except Exception as e: - logging.warning(f"Error getting agent status: {e}") - status['agents'] = {'error': str(e)} - - return status - - except Exception as e: - logging.error(f"Error getting application status: {e}", exc_info=True) - return {'error': str(e)} +def main(): + """Run the LG-SOTF API server.""" + # Configure logging + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + # Parse command-line arguments + parser = argparse.ArgumentParser(description="LG-SOTF SOC Dashboard API Server") + parser.add_argument("--config", "-c", default="configs/development.yaml", help="Configuration file path") + parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") + parser.add_argument("--port", type=int, default=8000, help="Port to bind to") + parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development") + args = parser.parse_args() -async def main(): - """Main entry point for LG-SOTF application.""" - import argparse - import json + # Print startup banner + print("🚀 Starting LG-SOTF SOC Dashboard API server...") + print(f"📊 API Documentation: http://{args.host}:{args.port}/api/docs") + print(f"🔌 WebSocket endpoint: ws://{args.host}:{args.port}/ws/{{client_id}}") + print(f"💊 Health check: http://{args.host}:{args.port}/api/v1/health") + print() - # Parse command line arguments - parser = argparse.ArgumentParser( - description="LG-SOTF: LangGraph SOC Triage & Orchestration Framework", - formatter_class=argparse.RawDescriptionHelpFormatter - ) - - parser.add_argument( - "--config", "-c", - type=str, - help="Path to configuration file" - ) - - parser.add_argument( - "--mode", "-m", - choices=["run", "health-check", "process-alert"], - default="run", - help="Application mode (default: run)" - ) - - parser.add_argument( - "--alert-id", - type=str, - help="Alert ID for process-alert mode" - ) - - parser.add_argument( - "--alert-data", - type=str, - help="Alert data JSON file for process-alert mode" - ) - - parser.add_argument( - "--log-level", "-l", - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - default="INFO", - help="Logging level (default: INFO)" - ) - - parser.add_argument( - "--version", "-v", - action="version", - version="LG-SOTF 0.1.0" - ) - - args = parser.parse_args() - - # Setup logging - logging.basicConfig( - level=getattr(logging, args.log_level), - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S" + # Run uvicorn server + uvicorn.run( + "lg_sotf.api.app:app", + host=args.host, + port=args.port, + reload=args.reload, + 
log_level="info", + access_log=True, + ws_ping_interval=20, + ws_ping_timeout=10, ) - - # Create application instance - app = LG_SOTFApplication(config_path=args.config) - - try: - # Initialize application - await app.initialize() - - # Execute based on mode - if args.mode == "health-check": - # Perform health check - health_status = await app.health_check() - status_info = app.get_application_status() - - print("\n" + "=" * 60) - print("🏥 LG-SOTF APPLICATION HEALTH CHECK") - print("=" * 60) - print(f"\nOverall Health: {'✅ HEALTHY' if health_status else '❌ UNHEALTHY'}") - print(f"Running: {'✅ Yes' if status_info['running'] else '❌ No'}") - print(f"Initialized: {'✅ Yes' if status_info['initialized'] else '❌ No'}") - print(f"Active Workflow Tasks: {status_info['active_workflow_tasks']}") - - print("\n" + "-" * 60) - print("🔧 COMPONENTS") - print("-" * 60) - for component, status in status_info['components'].items(): - status_icon = '✅' if status else '❌' - print(f" {status_icon} {component}") - - print("\n" + "-" * 60) - print("🤖 AGENTS") - print("-" * 60) - if 'error' in status_info['agents']: - print(f" ❌ Error: {status_info['agents']['error']}") - else: - agents_info = status_info['agents'] - print(f" Types: {agents_info['agent_types_count']}") - print(f" Instances: {agents_info['agent_instances_count']}") - print(f" Initialized: {len(agents_info['initialized_agents'])}") - if agents_info['initialized_agents']: - for agent_name in agents_info['initialized_agents']: - print(f" ✅ {agent_name}") - - print("\n" + "-" * 60) - print("💾 STORAGE") - print("-" * 60) - for storage_type, status in status_info['storage'].items(): - status_icon = '✅' if status else '❌' - print(f" {status_icon} {storage_type}") - - print("\n" + "=" * 60 + "\n") - - sys.exit(0 if health_status else 1) - - elif args.mode == "process-alert": - # Process a single alert - if not args.alert_id: - print("❌ Error: Alert ID is required for process-alert mode") - print("Usage: python -m lg_sotf.main --mode process-alert --alert-id [--alert-data ]") - sys.exit(1) - - # Load alert data - if args.alert_data: - with open(args.alert_data, 'r') as f: - alert_data = json.load(f) - else: - # Use sample alert data - alert_data = { - "id": args.alert_id, - "source": "manual", - "timestamp": datetime.utcnow().isoformat(), - "severity": "high", - "title": "Manual test alert", - "description": "Test alert for manual processing" - } - - print(f"\n🔄 Processing alert: {args.alert_id}") - result = await app.process_single_alert(args.alert_id, alert_data) - - print("\n" + "=" * 60) - print("✅ ALERT PROCESSED SUCCESSFULLY") - print("=" * 60) - print(f"Alert ID: {args.alert_id}") - print(f"Final Status: {result.get('triage_status', 'unknown')}") - print(f"Confidence Score: {result.get('confidence_score', 0)}") - print(f"Processing Notes: {len(result.get('processing_notes', []))}") - print("=" * 60 + "\n") - - sys.exit(0) - - else: - # Run application in continuous mode - await app.run() - - except KeyboardInterrupt: - logging.info("\nApplication interrupted by user") - sys.exit(0) - except Exception as e: - logging.error(f"Application failed: {e}", exc_info=True) - sys.exit(1) if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + main() diff --git a/tests/unit/test_core/test_state/test_state_manager.py b/tests/unit/test_core/test_state/test_state_manager.py new file mode 100644 index 00000000..4520f013 --- /dev/null +++ b/tests/unit/test_core/test_state/test_state_manager.py @@ -0,0 +1,571 @@ +""" +Comprehensive unit tests 
for StateManager. + +Tests cover: +- State creation and persistence +- State versioning and history +- State updates with conflict detection +- Historical query methods (correlation support) +- Agent execution tracking +- Workflow history tracking +""" + +import json +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, Mock, MagicMock, patch + +import pytest + +from lg_sotf.core.state.manager import StateManager +from lg_sotf.core.state.model import SOCState, StateVersion, AgentExecution, TriageStatus, AgentExecutionStatus +from lg_sotf.core.exceptions import StateError +from lg_sotf.storage.base import StorageBackend + + +@pytest.fixture +def mock_storage_backend(): + """Create a mock storage backend.""" + storage = Mock(spec=StorageBackend) + storage.save_state = AsyncMock() + storage.get_state = AsyncMock(return_value=None) + storage.get_state_history = AsyncMock(return_value=[]) + # Mock PostgreSQL pool for historical queries + storage.pool = Mock() + storage.pool.acquire = MagicMock() + return storage + + +@pytest.fixture +def state_manager(mock_storage_backend): + """Create a StateManager instance for testing.""" + return StateManager(storage_backend=mock_storage_backend) + + +@pytest.fixture +def sample_raw_alert(): + """Create a sample raw alert.""" + return { + "id": "alert-001", + "source": "test-siem", + "severity": "high", + "title": "Suspicious activity detected", + "timestamp": datetime.utcnow().isoformat(), + "raw_data": { + "source_ip": "192.168.1.100", + "destination_ip": "10.0.0.5", + "user": "admin", + "process": "powershell.exe" + } + } + + +# ======================================== +# STATE CREATION TESTS +# ======================================== + +class TestStateCreation: + """Test state creation and initialization.""" + + @pytest.mark.asyncio + async def test_create_state_success(self, state_manager, sample_raw_alert): + """Test successful state creation.""" + alert_id = "alert-001" + workflow_instance_id = "workflow-001" + + state = await state_manager.create_state( + alert_id=alert_id, + raw_alert=sample_raw_alert, + workflow_instance_id=workflow_instance_id, + initial_node="ingestion", + author_type="system", + author_id="test_system" + ) + + assert state.alert_id == alert_id + assert state.workflow_instance_id == workflow_instance_id + assert state.raw_alert == sample_raw_alert + assert state.current_node == "ingestion" + # add_version increments state_version, so it becomes 2 after adding the first version + assert state.state_version == 2 + assert len(state.version_history) == 1 + + @pytest.mark.asyncio + async def test_create_state_initializes_version_history(self, state_manager, sample_raw_alert): + """Test that state creation initializes version history.""" + state = await state_manager.create_state( + alert_id="alert-001", + raw_alert=sample_raw_alert, + workflow_instance_id="workflow-001", + initial_node="ingestion", + author_type="system", + author_id="test_system" + ) + + assert len(state.version_history) == 1 + first_version = state.version_history[0] + assert first_version.version == 1 + assert first_version.author_type == "system" + assert first_version.author_id == "test_system" + assert "Initial state creation" in first_version.changes_summary + + @pytest.mark.asyncio + async def test_create_state_persists_to_storage(self, state_manager, sample_raw_alert, mock_storage_backend): + """Test that state creation persists to storage.""" + await state_manager.create_state( + alert_id="alert-001", + raw_alert=sample_raw_alert, 
+ workflow_instance_id="workflow-001", + initial_node="ingestion", + author_type="system", + author_id="test_system" + ) + + # Verify storage was called + mock_storage_backend.save_state.assert_called_once() + + @pytest.mark.asyncio + async def test_create_state_logs_audit(self, state_manager, sample_raw_alert): + """Test that state creation logs audit trail.""" + with patch.object(state_manager.audit_logger, 'log_state_creation') as mock_log: + state = await state_manager.create_state( + alert_id="alert-001", + raw_alert=sample_raw_alert, + workflow_instance_id="workflow-001", + initial_node="ingestion", + author_type="system", + author_id="test_system" + ) + + mock_log.assert_called_once_with(state) + + +# ======================================== +# STATE UPDATE TESTS +# ======================================== + +class TestStateUpdates: + """Test state updates and versioning.""" + + @pytest.mark.asyncio + async def test_update_state_increments_version(self, state_manager, sample_raw_alert): + """Test that updating state increments version.""" + # Create initial state + state = await state_manager.create_state( + alert_id="alert-001", + raw_alert=sample_raw_alert, + workflow_instance_id="workflow-001", + initial_node="ingestion", + author_type="system", + author_id="test_system" + ) + + initial_version = state.state_version + + # Update state + updates = {"confidence_score": 75, "triage_status": "triaged"} + updated_state = await state_manager.update_state( + state=state, + updates=updates, + author_type="agent", + author_id="triage_agent", + changes_summary="Triage completed" + ) + + assert updated_state.state_version == initial_version + 1 + + @pytest.mark.asyncio + async def test_update_state_applies_updates(self, state_manager, sample_raw_alert): + """Test that updates are correctly applied to state.""" + state = await state_manager.create_state( + alert_id="alert-001", + raw_alert=sample_raw_alert, + workflow_instance_id="workflow-001", + initial_node="ingestion", + author_type="system", + author_id="test_system" + ) + + updates = { + "confidence_score": 85, + "triage_status": "triaged", + "priority_level": 2 + } + + updated_state = await state_manager.update_state( + state=state, + updates=updates, + author_type="agent", + author_id="triage_agent", + changes_summary="Triage completed" + ) + + assert updated_state.confidence_score == 85 + assert updated_state.triage_status == TriageStatus.TRIAGED + assert updated_state.priority_level == 2 + + @pytest.mark.asyncio + async def test_update_state_adds_version_record(self, state_manager, sample_raw_alert): + """Test that state updates add version records.""" + state = await state_manager.create_state( + alert_id="alert-001", + raw_alert=sample_raw_alert, + workflow_instance_id="workflow-001", + initial_node="ingestion", + author_type="system", + author_id="test_system" + ) + + initial_version_count = len(state.version_history) + + updates = {"confidence_score": 75} + updated_state = await state_manager.update_state( + state=state, + updates=updates, + author_type="agent", + author_id="triage_agent", + changes_summary="Triage completed" + ) + + assert len(updated_state.version_history) == initial_version_count + 1 + latest_version = updated_state.version_history[-1] + assert latest_version.author_id == "triage_agent" + assert "Triage completed" in latest_version.changes_summary + + @pytest.mark.asyncio + async def test_update_state_persists_changes(self, state_manager, sample_raw_alert, mock_storage_backend): + """Test that state updates 
are persisted.""" + state = await state_manager.create_state( + alert_id="alert-001", + raw_alert=sample_raw_alert, + workflow_instance_id="workflow-001", + initial_node="ingestion", + author_type="system", + author_id="test_system" + ) + + # Reset mock to count only update calls + mock_storage_backend.save_state.reset_mock() + + updates = {"confidence_score": 75} + await state_manager.update_state( + state=state, + updates=updates, + author_type="agent", + author_id="triage_agent", + changes_summary="Triage completed" + ) + + # Verify persistence was called + assert mock_storage_backend.save_state.called + + @pytest.mark.asyncio + async def test_update_state_nested_fields(self, state_manager, sample_raw_alert): + """Test updating nested fields in state.""" + state = await state_manager.create_state( + alert_id="alert-001", + raw_alert=sample_raw_alert, + workflow_instance_id="workflow-001", + initial_node="ingestion", + author_type="system", + author_id="test_system" + ) + + updates = { + "enriched_data.threat_intel": {"reputation": "malicious"}, + "metadata.analyst_notes": "Needs investigation" + } + + updated_state = await state_manager.update_state( + state=state, + updates=updates, + author_type="agent", + author_id="correlation_agent", + changes_summary="Added threat intel" + ) + + # Verify nested updates were applied + assert updated_state.enriched_data.get("threat_intel") == {"reputation": "malicious"} + assert updated_state.metadata.get("analyst_notes") == "Needs investigation" + + +# ======================================== +# AGENT EXECUTION TRACKING TESTS +# ======================================== + +class TestAgentExecutionTracking: + """Test agent execution tracking.""" + + @pytest.mark.asyncio + async def test_add_agent_execution(self, state_manager, sample_raw_alert): + """Test adding agent execution record.""" + state = await state_manager.create_state( + alert_id="alert-001", + raw_alert=sample_raw_alert, + workflow_instance_id="workflow-001", + initial_node="ingestion", + author_type="system", + author_id="test_system" + ) + + execution = AgentExecution( + agent_name="triage", + execution_id="exec-001", + start_time=datetime.utcnow(), + status=AgentExecutionStatus.COMPLETED, + inputs={"alert_id": "alert-001"}, + outputs={"confidence_score": 75} + ) + + updated_state = await state_manager.add_agent_execution( + state=state, + execution=execution, + author_type="agent", + author_id="triage_agent" + ) + + assert len(updated_state.agent_executions) == 1 + assert updated_state.agent_executions[0].agent_name == "triage" + assert updated_state.agent_executions[0].status == AgentExecutionStatus.COMPLETED + + +# ======================================== +# STATE RETRIEVAL TESTS +# ======================================== + +class TestStateRetrieval: + """Test state retrieval from storage.""" + + @pytest.mark.asyncio + async def test_get_state_success(self, state_manager, mock_storage_backend, sample_raw_alert): + """Test retrieving existing state.""" + # Mock storage to return state data + state_data = { + "alert_id": "alert-001", + "workflow_instance_id": "workflow-001", + "raw_alert": sample_raw_alert, + "enriched_data": {}, + "triage_status": "new", + "confidence_score": 0, + "fp_indicators": [], + "tp_indicators": [], + "current_node": "ingestion", + "next_nodes": ["ingestion"], + "state_version": 1, + "created_at": datetime.utcnow(), + "last_updated": datetime.utcnow(), + "version_history": [], + "agent_executions": [], + "workflow_history": [], + "human_feedback": None, + 
"escalation_level": 0, + "assigned_analyst": None, + "response_actions": [], + "playbook_executed": None, + "metadata": {}, + "tags": [], + "priority_level": 3 + } + mock_storage_backend.get_state.return_value = state_data + + state = await state_manager.get_state("alert-001", "workflow-001") + + assert state is not None + assert state.alert_id == "alert-001" + + @pytest.mark.asyncio + async def test_get_state_not_found(self, state_manager, mock_storage_backend): + """Test retrieving non-existent state.""" + mock_storage_backend.get_state.return_value = None + + state = await state_manager.get_state("nonexistent", "workflow-999") + + assert state is None + + +# ======================================== +# HISTORICAL QUERY TESTS (CORRELATION SUPPORT) +# ======================================== + +class TestHistoricalQueries: + """Test historical query methods for correlation.""" + + @pytest.mark.asyncio + async def test_query_alerts_by_indicator(self, state_manager, mock_storage_backend): + """Test querying alerts by specific indicator.""" + # Mock PostgreSQL connection and results + mock_conn = AsyncMock() + mock_conn.fetch = AsyncMock(return_value=[ + { + "alert_id": "alert-001", + "workflow_instance_id": "workflow-001", + "state_data": json.dumps({"alert_id": "alert-001"}), + "created_at": datetime.utcnow() + } + ]) + + mock_storage_backend.pool.acquire.return_value.__aenter__.return_value = mock_conn + + results = await state_manager.query_alerts_by_indicator( + indicator_type="source_ip", + indicator_value="192.168.1.100", + time_window_minutes=60, + limit=10 + ) + + assert len(results) > 0 + assert results[0]["alert_id"] == "alert-001" + + @pytest.mark.asyncio + async def test_query_similar_alerts(self, state_manager): + """Test querying similar alerts based on multiple indicators.""" + alert_data = { + "raw_data": { + "source_ip": "192.168.1.100", + "user": "admin", + "file_hash": "abc123" + } + } + + # Mock the query_alerts_by_indicator method + with patch.object(state_manager, 'query_alerts_by_indicator', new_callable=AsyncMock) as mock_query: + mock_query.return_value = [ + { + "alert_id": "alert-002", + "workflow_instance_id": "workflow-002", + "state_data": {}, + "created_at": datetime.utcnow().isoformat() + } + ] + + results = await state_manager.query_similar_alerts( + alert_data=alert_data, + similarity_threshold=0.5, + time_window_minutes=1440, + limit=50 + ) + + # Should have called query_alerts_by_indicator for each indicator + assert mock_query.call_count == 3 # source_ip, user, file_hash + + @pytest.mark.asyncio + async def test_get_alert_frequency(self, state_manager): + """Test getting alert frequency statistics.""" + # Mock query_alerts_by_indicator to return sample alerts + with patch.object(state_manager, 'query_alerts_by_indicator', new_callable=AsyncMock) as mock_query: + mock_query.return_value = [ + {"alert_id": f"alert-{i}", "created_at": datetime.utcnow().isoformat()} + for i in range(5) + ] + + stats = await state_manager.get_alert_frequency( + indicator_type="source_ip", + indicator_value="192.168.1.100", + time_window_minutes=60 + ) + + assert stats["total_count"] == 5 + assert stats["indicator_type"] == "source_ip" + assert stats["indicator_value"] == "192.168.1.100" + assert "alerts_per_hour" in stats + assert stats["alerts_per_hour"] == 5.0 # 5 alerts in 1 hour window + + +# ======================================== +# ERROR HANDLING TESTS +# ======================================== + +class TestErrorHandling: + """Test error handling in state manager.""" 
+ + @pytest.mark.asyncio + async def test_create_state_storage_failure(self, state_manager, sample_raw_alert, mock_storage_backend): + """Test handling of storage failure during state creation.""" + mock_storage_backend.save_state.side_effect = Exception("Storage error") + + with pytest.raises(StateError): + await state_manager.create_state( + alert_id="alert-001", + raw_alert=sample_raw_alert, + workflow_instance_id="workflow-001", + initial_node="ingestion", + author_type="system", + author_id="test_system" + ) + + @pytest.mark.asyncio + async def test_update_state_storage_failure(self, state_manager, sample_raw_alert, mock_storage_backend): + """Test handling of storage failure during state update.""" + state = await state_manager.create_state( + alert_id="alert-001", + raw_alert=sample_raw_alert, + workflow_instance_id="workflow-001", + initial_node="ingestion", + author_type="system", + author_id="test_system" + ) + + # Make save_state fail on next call + mock_storage_backend.save_state.side_effect = Exception("Storage error") + + with pytest.raises(StateError): + await state_manager.update_state( + state=state, + updates={"confidence_score": 75}, + author_type="agent", + author_id="triage_agent", + changes_summary="Update failed" + ) + + @pytest.mark.asyncio + async def test_get_state_storage_failure(self, state_manager, mock_storage_backend): + """Test handling of storage failure during state retrieval.""" + mock_storage_backend.get_state.side_effect = Exception("Storage error") + + with pytest.raises(StateError): + await state_manager.get_state("alert-001", "workflow-001") + + +# ======================================== +# STATE HASHING TESTS +# ======================================== + +class TestStateHashing: + """Test state hashing for change detection.""" + + @pytest.mark.asyncio + async def test_hash_state_consistency(self, state_manager, sample_raw_alert): + """Test that hashing same state produces same hash.""" + state = await state_manager.create_state( + alert_id="alert-001", + raw_alert=sample_raw_alert, + workflow_instance_id="workflow-001", + initial_node="ingestion", + author_type="system", + author_id="test_system" + ) + + hash1 = state_manager._hash_state(state) + hash2 = state_manager._hash_state(state) + + assert hash1 == hash2 + + @pytest.mark.asyncio + async def test_hash_state_detects_changes(self, state_manager, sample_raw_alert): + """Test that hashing detects state changes.""" + state = await state_manager.create_state( + alert_id="alert-001", + raw_alert=sample_raw_alert, + workflow_instance_id="workflow-001", + initial_node="ingestion", + author_type="system", + author_id="test_system" + ) + + hash_before = state_manager._hash_state(state) + + # Modify state + state.confidence_score = 75 + + hash_after = state_manager._hash_state(state) + + assert hash_before != hash_after diff --git a/tests/unit/test_core/test_workflow.py b/tests/unit/test_core/test_workflow.py index cfc4ea17..0ed81735 100644 --- a/tests/unit/test_core/test_workflow.py +++ b/tests/unit/test_core/test_workflow.py @@ -1,68 +1,624 @@ """ -Tests for the workflow engine. +Comprehensive unit tests for WorkflowEngine. 
+ +Tests cover: +- Routing logic (confidence-based decisions) +- State management (atomic updates, versioning) +- Execution guards (locks, duplicate prevention) +- Error handling and recovery +- LLM routing integration """ -from unittest.mock import AsyncMock, Mock +import asyncio +from datetime import datetime +from unittest.mock import AsyncMock, Mock, patch, MagicMock import pytest -import pytest_asyncio +from lg_sotf.core.workflow import WorkflowEngine, WorkflowState, ExecutionContext, RoutingDecision from lg_sotf.core.config.manager import ConfigManager from lg_sotf.core.state.manager import StateManager -from lg_sotf.core.workflow import WorkflowEngine -from lg_sotf.storage.postgres import PostgreSQLStorage - - -class TestWorkflowEngine: - """Test cases for WorkflowEngine.""" - - @pytest_asyncio.fixture - async def workflow_engine(self, config_manager, state_manager): - """Create a test workflow engine.""" - config = config_manager # No await, assuming synchronous - state = await state_manager # Await to resolve coroutine - return WorkflowEngine(config, state) - - @pytest.mark.asyncio - async def test_workflow_engine_initialization(self, workflow_engine): - """Test workflow engine initialization.""" - assert workflow_engine.config is not None - assert workflow_engine.state_manager is not None - assert workflow_engine.graph is not None - assert workflow_engine.compiled_graph is not None - - @pytest.mark.asyncio - async def test_execute_workflow_success(self, workflow_engine, sample_alert): - """Test successful workflow execution.""" - workflow_engine.state_manager.create_state = AsyncMock() - workflow_engine.state_manager.create_state.return_value = Mock() - workflow_engine.state_manager.create_state.return_value.dict.return_value = sample_alert - - workflow_engine.compiled_graph.ainvoke = AsyncMock() - workflow_engine.compiled_graph.ainvoke.return_value = {"status": "completed"} - - result = await workflow_engine.execute_workflow("test-alert-001", sample_alert) - - assert result == {"status": "completed"} - workflow_engine.state_manager.create_state.assert_called_once() - workflow_engine.compiled_graph.ainvoke.assert_called_once() - - @pytest.mark.asyncio - async def test_execute_workflow_error(self, workflow_engine, sample_alert): - """Test workflow execution with error.""" - workflow_engine.state_manager.create_state = AsyncMock() - workflow_engine.state_manager.create_state.side_effect = Exception("Test error") - - with pytest.raises(Exception) as exc_info: - await workflow_engine.execute_workflow("test-alert-001", sample_alert) - - assert "Test error" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_build_workflow_graph(self, workflow_engine): - """Test workflow graph building.""" - graph = workflow_engine._build_workflow_graph() # Synchronous method - - assert graph is not None - assert hasattr(graph, 'nodes') - assert hasattr(graph, 'edges') \ No newline at end of file +from lg_sotf.core.state.model import TriageStatus +from lg_sotf.core.exceptions import WorkflowError + + +@pytest.fixture +def mock_config_manager(): + """Create a mock ConfigManager.""" + config = Mock(spec=ConfigManager) + config.get = Mock(side_effect=lambda key, default=None: { + 'workflow.enable_llm_routing': False, # Disable LLM for unit tests + 'routing.max_alert_age_hours': 72, + 'routing.correlation_grey_zone_min': 30, + 'routing.correlation_grey_zone_max': 70, + 'routing.analysis_threshold': 40, + 'routing.human_review_min': 20, + 'routing.human_review_max': 60, + 
'routing.response_threshold': 80, + }.get(key, default)) + config.get_agent_config = Mock(return_value={}) + return config + + +@pytest.fixture +def mock_state_manager(): + """Create a mock StateManager.""" + manager = Mock(spec=StateManager) + manager.create_state = AsyncMock(return_value=Mock(alert_id="test-001")) + manager.update_state = AsyncMock() + manager.get_state = AsyncMock(return_value=None) + return manager + + +@pytest.fixture +def mock_redis_storage(): + """Create a mock Redis storage.""" + return Mock() + + +@pytest.fixture +def mock_tool_orchestrator(): + """Create a mock ToolOrchestrator.""" + return Mock() + + +@pytest.fixture +def workflow_engine(mock_config_manager, mock_state_manager, mock_redis_storage, mock_tool_orchestrator): + """Create a WorkflowEngine instance for testing.""" + engine = WorkflowEngine( + config_manager=mock_config_manager, + state_manager=mock_state_manager, + redis_storage=mock_redis_storage, + tool_orchestrator=mock_tool_orchestrator + ) + + # Mock the graph compilation (synchronous) + engine.compiled_graph = Mock() + engine.compiled_graph.ainvoke = AsyncMock() + + # Create mock agents (don't await initialize) + for agent_type in ['ingestion', 'triage', 'correlation', 'analysis', 'human_loop', 'response']: + mock_agent = Mock() + mock_agent.execute = AsyncMock(return_value={}) + mock_agent.initialize = AsyncMock() + engine.agents[agent_type] = mock_agent + engine._agent_locks[agent_type] = asyncio.Lock() + + return engine + + +@pytest.fixture +def sample_workflow_state(): + """Create a sample workflow state.""" + return { + "alert_id": "test-alert-001", + "workflow_instance_id": "workflow-001", + "execution_context": { + "execution_id": "exec-001", + "started_at": datetime.utcnow().isoformat(), + "last_node": "start", + "executed_nodes": [], + "execution_time": datetime.utcnow().isoformat() + }, + "raw_alert": { + "id": "test-alert-001", + "severity": "high", + "title": "Suspicious login", + "raw_data": { + "source_ip": "192.168.1.100", + "user": "admin" + } + }, + "enriched_data": {}, + "triage_status": "new", + "confidence_score": 50, + "current_node": "triage", + "priority_level": 3, + "fp_indicators": [], + "tp_indicators": [], + "correlations": [], + "correlation_score": 0, + "analysis_conclusion": "", + "threat_score": 0, + "recommended_actions": [], + "analysis_reasoning": [], + "tool_results": {}, + "processing_notes": [], + "last_updated": datetime.utcnow().isoformat(), + "agent_executions": {}, + "state_version": 1 + } + + +# ======================================== +# ROUTING LOGIC TESTS +# ======================================== + +class TestRoutingAfterTriage: + """Test routing decisions after triage agent.""" + + @pytest.mark.asyncio + async def test_route_close_low_confidence_with_fp_indicators(self, workflow_engine, sample_workflow_state): + """Test that low confidence + FP indicators routes to close.""" + state = sample_workflow_state.copy() + state["confidence_score"] = 10 + state["fp_indicators"] = ["benign_process", "known_internal_ip"] + + result = await workflow_engine._route_after_triage(state) + + assert result == "close" + + @pytest.mark.asyncio + async def test_route_close_more_fp_than_tp(self, workflow_engine, sample_workflow_state): + """Test that more FP than TP indicators routes to close.""" + state = sample_workflow_state.copy() + state["confidence_score"] = 18 + state["fp_indicators"] = ["indicator1", "indicator2", "indicator3"] + state["tp_indicators"] = ["indicator1"] + + result = await 
workflow_engine._route_after_triage(state) + + assert result == "close" + + @pytest.mark.asyncio + async def test_route_correlation_moderate_confidence(self, workflow_engine, sample_workflow_state): + """Test that moderate confidence routes to correlation.""" + state = sample_workflow_state.copy() + state["confidence_score"] = 50 + state["fp_indicators"] = [] + state["tp_indicators"] = ["suspicious_activity"] + + result = await workflow_engine._route_after_triage(state) + + # Should route to correlation (builds threat intel) + assert result == "correlation" + + @pytest.mark.asyncio + async def test_route_correlation_with_network_indicators(self, workflow_engine, sample_workflow_state): + """Test that alerts with network indicators route to correlation.""" + state = sample_workflow_state.copy() + state["confidence_score"] = 50 + state["raw_alert"]["raw_data"] = {"source_ip": "192.168.1.100"} + + result = await workflow_engine._route_after_triage(state) + + assert result == "correlation" + + +class TestRoutingAfterCorrelation: + """Test routing decisions after correlation agent.""" + + @pytest.mark.asyncio + async def test_route_response_strong_correlations(self, workflow_engine, sample_workflow_state): + """Test that strong correlations route to response.""" + state = sample_workflow_state.copy() + state["correlation_score"] = 90 + state["correlations"] = [ + {"alert_id": "alert-1"}, + {"alert_id": "alert-2"}, + {"alert_id": "alert-3"}, + {"alert_id": "alert-4"}, + {"alert_id": "alert-5"} + ] + + result = workflow_engine._route_after_correlation(state) + + assert result == "response" + + @pytest.mark.asyncio + async def test_route_analysis_moderate_correlations(self, workflow_engine, sample_workflow_state): + """Test that moderate correlations route to analysis.""" + state = sample_workflow_state.copy() + state["correlation_score"] = 65 + state["correlations"] = [{"alert_id": "alert-1"}, {"alert_id": "alert-2"}, {"alert_id": "alert-3"}] + + result = workflow_engine._route_after_correlation(state) + + assert result == "analysis" + + @pytest.mark.asyncio + async def test_route_close_no_correlations_low_confidence(self, workflow_engine, sample_workflow_state): + """Test that no correlations + low confidence routes to close.""" + state = sample_workflow_state.copy() + state["confidence_score"] = 25 + state["correlation_score"] = 15 + state["correlations"] = [] + + result = workflow_engine._route_after_correlation(state) + + assert result == "close" + + @pytest.mark.asyncio + async def test_route_human_loop_weak_correlations_high_confidence(self, workflow_engine, sample_workflow_state): + """Test that weak correlations + high confidence routes to human loop.""" + state = sample_workflow_state.copy() + state["confidence_score"] = 55 + state["correlation_score"] = 25 + state["correlations"] = [] + + result = workflow_engine._route_after_correlation(state) + + assert result == "human_loop" + + +class TestRoutingAfterAnalysis: + """Test routing decisions after analysis agent.""" + + @pytest.mark.asyncio + async def test_route_response_high_threat_score(self, workflow_engine, sample_workflow_state): + """Test that high threat score routes to response.""" + state = sample_workflow_state.copy() + state["threat_score"] = 85 + state["confidence_score"] = 85 + + result = workflow_engine._route_after_analysis(state) + + assert result == "response" + + @pytest.mark.asyncio + async def test_route_close_low_threat(self, workflow_engine, sample_workflow_state): + """Test that low threat score routes to 
close.""" + state = sample_workflow_state.copy() + state["threat_score"] = 25 + state["confidence_score"] = 35 + + result = workflow_engine._route_after_analysis(state) + + # With confidence 35, this might route to human_loop based on grey zone logic + assert result in ["close", "human_loop"] + + @pytest.mark.asyncio + async def test_route_human_loop_uncertain_conclusion(self, workflow_engine, sample_workflow_state): + """Test that uncertain analysis routes to human loop.""" + state = sample_workflow_state.copy() + state["threat_score"] = 50 + state["confidence_score"] = 55 + state["analysis_conclusion"] = "Analysis is uncertain about threat nature" + + result = workflow_engine._route_after_analysis(state) + + assert result == "human_loop" + + +# ======================================== +# STATE MANAGEMENT TESTS +# ======================================== + +class TestStateManagement: + """Test state management and atomic updates.""" + + @pytest.mark.asyncio + async def test_execution_context_creation(self, workflow_engine): + """Test that execution context is created correctly.""" + alert_id = "test-alert-001" + + context = workflow_engine._create_execution_context(alert_id) + + assert context.execution_id.startswith(alert_id) + assert context.started_at is not None + assert len(context.locks) == 7 # All node locks + assert context.node_executions == {} + + @pytest.mark.asyncio + async def test_convert_to_agent_format(self, workflow_engine, sample_workflow_state): + """Test conversion of workflow state to agent format.""" + state = sample_workflow_state.copy() + + agent_input = workflow_engine._convert_to_agent_format(state) + + assert agent_input["alert_id"] == state["alert_id"] + assert agent_input["raw_alert"] == state["raw_alert"] + assert agent_input["confidence_score"] == state["confidence_score"] + assert agent_input["fp_indicators"] == state["fp_indicators"] + assert "metadata" in agent_input + + +# ======================================== +# EXECUTION GUARDS TESTS +# ======================================== + +class TestExecutionGuards: + """Test duplicate execution prevention and locks.""" + + @pytest.mark.asyncio + async def test_duplicate_execution_prevented(self, workflow_engine, sample_workflow_state): + """Test that duplicate agent execution is prevented.""" + state = sample_workflow_state.copy() + alert_id = state["alert_id"] + + # Create execution context and mark triage as executed + context = workflow_engine._create_execution_context(alert_id) + context.node_executions["triage"] = True + + # Mark agent as already executed in state + state["agent_executions"][f"triage_{alert_id}"] = { + "executed_at": datetime.utcnow().isoformat(), + "status": "completed" + } + + # Attempt to execute triage again + updates = await workflow_engine._execute_triage(state) + + # Should skip execution since it was already marked as completed + assert updates == {} + + @pytest.mark.asyncio + async def test_concurrent_execution_prevented_by_locks(self, workflow_engine, sample_workflow_state): + """Test that concurrent execution is prevented by locks.""" + state1 = sample_workflow_state.copy() + state2 = sample_workflow_state.copy() + alert_id = state1["alert_id"] + + # Create execution context + workflow_engine._create_execution_context(alert_id) + + # Mock agent to simulate slow execution + async def slow_execute(input_state): + await asyncio.sleep(0.05) + return {"confidence_score": 50, "fp_indicators": [], "tp_indicators": []} + + workflow_engine.agents["triage"].execute = slow_execute + + # 
Start two concurrent executions with separate state copies
+        task1 = asyncio.create_task(workflow_engine._execute_triage(state1))
+        task2 = asyncio.create_task(workflow_engine._execute_triage(state2))
+
+        results = await asyncio.gather(task1, task2)
+
+        # Both coroutines complete without raising: the per-node lock
+        # serializes them, and node_executions tracking makes the second
+        # caller skip the agent body instead of running it again.
+        assert len(results) == 2  # Both tasks complete
+        # At least the first task to acquire the lock returns real updates
+        has_results = [bool(r) for r in results]
+        assert any(has_results)
+
+
+# ========================================
+# ERROR HANDLING TESTS
+# ========================================
+
+class TestErrorHandling:
+    """Test error handling and recovery."""
+
+    @pytest.mark.asyncio
+    async def test_agent_execution_error_handling(self, workflow_engine, sample_workflow_state):
+        """Test that agent execution errors are handled gracefully."""
+        state = sample_workflow_state.copy()
+
+        # Create execution context
+        workflow_engine._create_execution_context(state["alert_id"])
+
+        # Mock agent to raise an error
+        workflow_engine.agents["triage"].execute = AsyncMock(side_effect=Exception("Agent failed"))
+
+        updates = await workflow_engine._execute_triage(state)
+
+        # Should return error in processing notes
+        assert "processing_notes" in updates
+        assert any("error" in note.lower() or "failed" in note.lower() for note in updates["processing_notes"])
+
+    @pytest.mark.asyncio
+    async def test_missing_execution_context_handling(self, workflow_engine, sample_workflow_state):
+        """Test handling of missing execution context."""
+        state = sample_workflow_state.copy()
+
+        # Don't create execution context (simulate error condition)
+        # Attempt to execute node
+        updates = await workflow_engine._execute_with("triage", workflow_engine._execute_triage, state)
+
+        # Should handle gracefully and return informative feedback
+        assert "processing_notes" in updates
+        assert any("Missing execution context" in note for note in updates["processing_notes"])
+
+    @pytest.mark.asyncio
+    async def test_workflow_initialization_failure(self, mock_config_manager, mock_state_manager):
+        """Test that workflow initialization failures are caught."""
+        engine = WorkflowEngine(
+            config_manager=mock_config_manager,
+            state_manager=mock_state_manager,
+            redis_storage=None,
+            tool_orchestrator=None
+        )
+
+        # Mock agent setup to fail
+        with patch.object(engine, '_setup_agents', side_effect=Exception("Setup failed")):
+            with pytest.raises(WorkflowError):
+                await engine.initialize()
+
+
+# ========================================
+# ROUTING HELPER TESTS
+# ========================================
+
+class TestRoutingHelpers:
+    """Test routing helper methods."""
+
+    @pytest.mark.asyncio
+    async def test_needs_correlation_with_network_indicators(self, workflow_engine, sample_workflow_state):
+        """Test detection of network indicators."""
+        state = sample_workflow_state.copy()
+        state["enriched_data"]["source_ip"] = "192.168.1.100"
+
+        result = workflow_engine._needs_correlation(state)
+
+        assert result is True
+
+    @pytest.mark.asyncio
+    async def test_needs_correlation_with_user_indicators(self, workflow_engine, sample_workflow_state):
+        """Test detection of user indicators."""
+        state = sample_workflow_state.copy()
+        # Set user at the raw_alert level (not raw_data)
+        state["raw_alert"]["user"] = "admin"
+
+        result = workflow_engine._needs_correlation(state)
+
+        assert result is True
+
+    
@pytest.mark.asyncio + async def test_needs_correlation_with_file_indicators(self, workflow_engine, sample_workflow_state): + """Test detection of file indicators.""" + state = sample_workflow_state.copy() + state["enriched_data"]["file_hash"] = "abc123" + + result = workflow_engine._needs_correlation(state) + + assert result is True + + @pytest.mark.asyncio + async def test_needs_analysis_low_confidence_mixed_signals(self, workflow_engine, sample_workflow_state): + """Test analysis needed for low confidence with mixed signals.""" + state = sample_workflow_state.copy() + state["confidence_score"] = 35 + state["fp_indicators"] = ["indicator1"] + state["tp_indicators"] = ["indicator2"] + + result = workflow_engine._needs_analysis(state) + + assert result is True + + @pytest.mark.asyncio + async def test_needs_analysis_complex_category(self, workflow_engine, sample_workflow_state): + """Test analysis needed for complex attack categories.""" + state = sample_workflow_state.copy() + state["enriched_data"]["category"] = "lateral_movement" + + result = workflow_engine._needs_analysis(state) + + assert result is True + + +# ======================================== +# INTEGRATION TESTS (WITHIN WORKFLOW ENGINE) +# ======================================== + +class TestWorkflowIntegration: + """Test integrated workflow execution.""" + + @pytest.mark.asyncio + async def test_triage_execution_updates_state(self, workflow_engine, sample_workflow_state): + """Test that triage execution properly updates state.""" + state = sample_workflow_state.copy() + + # Create execution context + workflow_engine._create_execution_context(state["alert_id"]) + + # Mock triage agent + workflow_engine.agents["triage"].execute = AsyncMock(return_value={ + "confidence_score": 75, + "fp_indicators": [], + "tp_indicators": ["suspicious_login", "unusual_time"], + "priority_level": 2, + "triage_status": "triaged" + }) + + updates = await workflow_engine._execute_triage(state) + + assert updates["confidence_score"] == 75 + assert len(updates["tp_indicators"]) == 2 + assert updates["triage_status"] == "triaged" + + @pytest.mark.asyncio + async def test_correlation_execution_updates_state(self, workflow_engine, sample_workflow_state): + """Test that correlation execution properly updates state.""" + state = sample_workflow_state.copy() + + # Create execution context + workflow_engine._create_execution_context(state["alert_id"]) + + # Mock correlation agent + workflow_engine.agents["correlation"].execute = AsyncMock(return_value={ + "correlations": [ + {"alert_id": "alert-1", "similarity": 0.8}, + {"alert_id": "alert-2", "similarity": 0.7} + ], + "correlation_score": 75, + "confidence_score": 65 + }) + + updates = await workflow_engine._execute_correlation(state) + + assert len(updates["correlations"]) == 2 + assert updates["correlation_score"] == 75 + + @pytest.mark.asyncio + async def test_analysis_execution_updates_state(self, workflow_engine, sample_workflow_state): + """Test that analysis execution properly updates state.""" + state = sample_workflow_state.copy() + + # Create execution context + workflow_engine._create_execution_context(state["alert_id"]) + + # Mock analysis agent + workflow_engine.agents["analysis"].execute = AsyncMock(return_value={ + "threat_score": 80, + "analysis_conclusion": "Confirmed threat", + "recommended_actions": ["isolate_host", "block_ip"], + "analysis_reasoning": [{"step": 1, "action": "checked_threat_intel"}], + "tool_results": {"virustotal": {"malicious": True}} + }) + + updates = await 
workflow_engine._execute_analysis(state) + + assert updates["threat_score"] == 80 + assert updates["analysis_conclusion"] == "Confirmed threat" + assert len(updates["recommended_actions"]) == 2 + + +# ======================================== +# WORKFLOW METRICS TESTS +# ======================================== + +class TestWorkflowMetrics: + """Test workflow metrics collection.""" + + @pytest.mark.asyncio + async def test_get_workflow_metrics(self, workflow_engine): + """Test that workflow metrics are collected correctly.""" + metrics = workflow_engine.get_workflow_metrics() + + assert "active_executions" in metrics + assert "total_agents" in metrics + assert "agent_locks" in metrics + assert metrics["synchronization_enabled"] is True + assert metrics["total_agents"] == len(workflow_engine.agents) + + +# ======================================== +# LLM ROUTING TESTS (FALLBACK) +# ======================================== + +class TestLLMRouting: + """Test LLM-enhanced routing (fallback mode).""" + + @pytest.mark.asyncio + async def test_fallback_routing_after_triage_close_obvious_fp(self, workflow_engine, sample_workflow_state): + """Test fallback routing closes obvious FPs.""" + state = sample_workflow_state.copy() + state["confidence_score"] = 8 + state["fp_indicators"] = ["benign1", "benign2"] + + result = workflow_engine._fallback_routing(state, "triage") + + assert result == "close" + + @pytest.mark.asyncio + async def test_fallback_routing_after_triage_response_high_confidence(self, workflow_engine, sample_workflow_state): + """Test fallback routing sends high confidence to response.""" + state = sample_workflow_state.copy() + state["confidence_score"] = 90 + state["tp_indicators"] = ["mal1", "mal2", "mal3"] + + result = workflow_engine._fallback_routing(state, "triage") + + assert result == "response" + + @pytest.mark.asyncio + async def test_fallback_routing_prefers_correlation(self, workflow_engine, sample_workflow_state): + """Test fallback routing prefers correlation for grey zone.""" + state = sample_workflow_state.copy() + state["confidence_score"] = 50 + state["raw_alert"]["raw_data"]["source_ip"] = "192.168.1.100" + + result = workflow_engine._fallback_routing(state, "triage") + + assert result == "correlation"
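+
+    # Illustrative consolidation sketch, not part of the original suite: the
+    # three scenarios above expressed as a single parametrized decision table.
+    # It assumes _fallback_routing is deterministic for these inputs, which is
+    # exactly what the individual tests above already assert.
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("confidence,fp,tp,extra_raw,expected", [
+        (8, ["benign1", "benign2"], [], {}, "close"),
+        (90, [], ["mal1", "mal2", "mal3"], {}, "response"),
+        (50, [], [], {"source_ip": "192.168.1.100"}, "correlation"),
+    ])
+    async def test_fallback_routing_decision_table(self, workflow_engine, sample_workflow_state,
+                                                   confidence, fp, tp, extra_raw, expected):
+        """Sketch: fallback routing decision table for the triage stage."""
+        state = sample_workflow_state.copy()
+        state["confidence_score"] = confidence
+        state["fp_indicators"] = fp
+        state["tp_indicators"] = tp
+        state["raw_alert"]["raw_data"].update(extra_raw)
+
+        result = workflow_engine._fallback_routing(state, "triage")
+
+        assert result == expected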