diff --git a/state-manager/app/controller/cancel_triggers.py b/state-manager/app/controller/cancel_triggers.py new file mode 100644 index 00000000..903a9980 --- /dev/null +++ b/state-manager/app/controller/cancel_triggers.py @@ -0,0 +1,67 @@ +""" +Controller for cancelling pending triggers for a graph +""" +import asyncio +from app.models.cancel_trigger_models import CancelTriggerResponse +from app.models.db.trigger import DatabaseTriggers +from app.models.trigger_models import TriggerStatusEnum +from app.singletons.logs_manager import LogsManager +from app.config.settings import get_settings +from app.tasks.trigger_cron import mark_as_cancelled +from beanie.operators import In + +logger = LogsManager().get_logger() + +async def cancel_triggers(namespace_name: str, graph_name: str, x_exosphere_request_id: str) -> CancelTriggerResponse: + """ + Cancel all pending or triggering triggers for a specific graph + + Args: + namespace_name: The namespace of the graph + graph_name: The name of the graph + x_exosphere_request_id: Request ID for logging + + Returns: + CancelTriggerResponse with cancellation details + """ + try: + logger.info(f"Request to cancel triggers for graph {graph_name} in namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + + # Find all PENDING or TRIGGERING triggers for this graph + triggers = await DatabaseTriggers.find( + DatabaseTriggers.namespace == namespace_name, + DatabaseTriggers.graph_name == graph_name, + In(DatabaseTriggers.trigger_status, [TriggerStatusEnum.PENDING, TriggerStatusEnum.TRIGGERING]) + ).to_list() + + if not triggers: + logger.info(f"No pending triggers found for graph {graph_name} in namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + return CancelTriggerResponse( + namespace=namespace_name, + graph_name=graph_name, + cancelled_count=0, + message="No pending triggers found to cancel" + ) + + # Get retention hours from settings + settings = get_settings() + retention_hours = settings.trigger_retention_hours + + # Cancel each trigger concurrently + cancelled_count = len(triggers) + cancellation_tasks = [mark_as_cancelled(trigger, retention_hours) for trigger in triggers] + await asyncio.gather(*cancellation_tasks) + + logger.info(f"Cancelled {cancelled_count} triggers for graph {graph_name} in namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + + return CancelTriggerResponse( + namespace=namespace_name, + graph_name=graph_name, + cancelled_count=cancelled_count, + message=f"Successfully cancelled {cancelled_count} trigger(s)" + ) + + except Exception as e: + logger.error(f"Error cancelling triggers for graph {graph_name} in namespace {namespace_name}: {str(e)}", x_exosphere_request_id=x_exosphere_request_id) + raise + diff --git a/state-manager/app/models/cancel_trigger_models.py b/state-manager/app/models/cancel_trigger_models.py new file mode 100644 index 00000000..7c52bc13 --- /dev/null +++ b/state-manager/app/models/cancel_trigger_models.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel, Field + + +class CancelTriggerResponse(BaseModel): + namespace: str = Field(..., description="Namespace of the cancelled triggers") + graph_name: str = Field(..., description="Name of the graph") + cancelled_count: int = Field(..., description="Number of triggers that were cancelled") + message: str = Field(..., description="Human-readable message describing the result") + diff --git a/state-manager/app/models/db/trigger.py b/state-manager/app/models/db/trigger.py index a416193e..9a4998ce 100644 --- a/state-manager/app/models/db/trigger.py +++ b/state-manager/app/models/db/trigger.py @@ -44,7 +44,8 @@ class Settings: "trigger_status": { "$in": [ TriggerStatusEnum.TRIGGERED, - TriggerStatusEnum.FAILED + TriggerStatusEnum.FAILED, + TriggerStatusEnum.CANCELLED ] } } diff --git a/state-manager/app/routes.py b/state-manager/app/routes.py index 03fca45a..4de7ca46 100644 --- a/state-manager/app/routes.py +++ b/state-manager/app/routes.py @@ -54,6 +54,10 @@ from .models.manual_retry import ManualRetryRequestModel, ManualRetryResponseModel from .controller.manual_retry_state import manual_retry_state +# cancel_triggers +from .models.cancel_trigger_models import CancelTriggerResponse +from .controller.cancel_triggers import cancel_triggers + logger = LogsManager().get_logger() @@ -237,6 +241,25 @@ async def get_graph_template(namespace_name: str, graph_name: str, request: Requ return await get_graph_template_controller(namespace_name, graph_name, x_exosphere_request_id) +@router.delete( + "/graph/{graph_name}/triggers", + response_model=CancelTriggerResponse, + status_code=status.HTTP_200_OK, + response_description="Triggers cancelled successfully", + tags=["graph"] +) +async def cancel_triggers_route(namespace_name: str, graph_name: str, request: Request, api_key: str = Depends(check_api_key)): + x_exosphere_request_id = getattr(request.state, "x_exosphere_request_id", str(uuid4())) + + if api_key: + logger.info(f"API key is valid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + else: + logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + + return await cancel_triggers(namespace_name, graph_name, x_exosphere_request_id) + + @router.put( "/nodes/", response_model=RegisterNodesResponseModel, diff --git a/state-manager/app/tasks/trigger_cron.py b/state-manager/app/tasks/trigger_cron.py index 3fca36ea..48771281 100644 --- a/state-manager/app/tasks/trigger_cron.py +++ b/state-manager/app/tasks/trigger_cron.py @@ -98,4 +98,15 @@ async def trigger_cron(): cron_time = datetime.now() settings = get_settings() logger.info(f"starting trigger_cron: {cron_time}") - await asyncio.gather(*[handle_trigger(cron_time, settings.trigger_retention_hours) for _ in range(settings.trigger_workers)]) \ No newline at end of file + await asyncio.gather(*[handle_trigger(cron_time, settings.trigger_retention_hours) for _ in range(settings.trigger_workers)]) + +async def mark_as_cancelled(trigger: DatabaseTriggers, retention_hours: int): + expires_at = datetime.now(timezone.utc) + timedelta(hours=retention_hours) + + await DatabaseTriggers.get_pymongo_collection().update_one( + {"_id": trigger.id}, + {"$set": { + "trigger_status": TriggerStatusEnum.CANCELLED, + "expires_at": expires_at + }} + ) \ No newline at end of file diff --git a/state-manager/tests/unit/controller/test_cancel_triggers.py b/state-manager/tests/unit/controller/test_cancel_triggers.py new file mode 100644 index 00000000..9878669c --- /dev/null +++ b/state-manager/tests/unit/controller/test_cancel_triggers.py @@ -0,0 +1,273 @@ +""" +Tests for cancel_triggers controller. +Verifies cancellation of pending and triggering triggers for a graph. +""" +import pytest +from unittest.mock import patch, MagicMock, AsyncMock + +from app.controller.cancel_triggers import cancel_triggers +from app.models.cancel_trigger_models import CancelTriggerResponse +from app.models.db.trigger import DatabaseTriggers + + +@pytest.mark.asyncio +async def test_cancel_triggers_success_with_pending(): + """Test successfully cancelling PENDING triggers""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + # Create mock triggers + mock_trigger1 = MagicMock(spec=DatabaseTriggers) + mock_trigger1.id = "trigger_id_1" + mock_trigger2 = MagicMock(spec=DatabaseTriggers) + mock_trigger2.id = "trigger_id_2" + + with patch('app.controller.cancel_triggers.DatabaseTriggers') as mock_db, \ + patch('app.controller.cancel_triggers.get_settings') as mock_get_settings, \ + patch('app.controller.cancel_triggers.mark_as_cancelled') as mock_mark_cancelled: + + # Setup mock database query + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[mock_trigger1, mock_trigger2]) + mock_db.find.return_value = mock_query + + # Setup mock settings + mock_settings = MagicMock() + mock_settings.trigger_retention_hours = 24 + mock_get_settings.return_value = mock_settings + + result = await cancel_triggers(namespace_name, graph_name, x_exosphere_request_id) + + # Verify result + assert isinstance(result, CancelTriggerResponse) + assert result.namespace == namespace_name + assert result.graph_name == graph_name + assert result.cancelled_count == 2 + assert "Successfully cancelled 2 trigger(s)" in result.message + + # Verify mark_as_cancelled was called for each trigger + assert mock_mark_cancelled.call_count == 2 + + +@pytest.mark.asyncio +async def test_cancel_triggers_success_with_triggering(): + """Test successfully cancelling TRIGGERING triggers""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + # Create mock trigger + mock_trigger = MagicMock(spec=DatabaseTriggers) + mock_trigger.id = "trigger_id_1" + + with patch('app.controller.cancel_triggers.DatabaseTriggers') as mock_db, \ + patch('app.controller.cancel_triggers.get_settings') as mock_get_settings, \ + patch('app.controller.cancel_triggers.mark_as_cancelled') as mock_mark_cancelled: + + # Setup mock database query + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[mock_trigger]) + mock_db.find.return_value = mock_query + + # Setup mock settings + mock_settings = MagicMock() + mock_settings.trigger_retention_hours = 12 + mock_get_settings.return_value = mock_settings + + result = await cancel_triggers(namespace_name, graph_name, x_exosphere_request_id) + + # Verify result + assert isinstance(result, CancelTriggerResponse) + assert result.cancelled_count == 1 + assert "Successfully cancelled 1 trigger(s)" in result.message + + # Verify mark_as_cancelled was called with retention_hours from settings + mock_mark_cancelled.assert_called_once_with(mock_trigger, 12) + + +@pytest.mark.asyncio +async def test_cancel_triggers_no_triggers_found(): + """Test cancelling triggers when no pending triggers exist""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + with patch('app.controller.cancel_triggers.DatabaseTriggers') as mock_db: + # Setup mock database query to return empty list + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[]) + mock_db.find.return_value = mock_query + + result = await cancel_triggers(namespace_name, graph_name, x_exosphere_request_id) + + # Verify result + assert isinstance(result, CancelTriggerResponse) + assert result.namespace == namespace_name + assert result.graph_name == graph_name + assert result.cancelled_count == 0 + assert "No pending triggers found to cancel" in result.message + + +@pytest.mark.asyncio +async def test_cancel_triggers_query_filters_correctly(): + """Test that the query filters by namespace, graph_name, and status""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + with patch('app.controller.cancel_triggers.DatabaseTriggers') as mock_db, \ + patch('app.controller.cancel_triggers.get_settings') as mock_get_settings, \ + patch('app.controller.cancel_triggers.mark_as_cancelled'): + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[]) + mock_db.find.return_value = mock_query + + mock_settings = MagicMock() + mock_settings.trigger_retention_hours = 24 + mock_get_settings.return_value = mock_settings + + await cancel_triggers(namespace_name, graph_name, x_exosphere_request_id) + + # Verify find was called with correct arguments + mock_db.find.assert_called_once() + call_args = mock_db.find.call_args + + # Check that all three conditions are in the call + # The call should include namespace, graph_name, and In for trigger_status + assert call_args is not None + + +@pytest.mark.asyncio +async def test_cancel_triggers_uses_settings_retention_hours(): + """Test that the function uses retention hours from settings""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + mock_trigger = MagicMock(spec=DatabaseTriggers) + mock_trigger.id = "trigger_id_1" + + with patch('app.controller.cancel_triggers.DatabaseTriggers') as mock_db, \ + patch('app.controller.cancel_triggers.get_settings') as mock_get_settings, \ + patch('app.controller.cancel_triggers.mark_as_cancelled') as mock_mark_cancelled: + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[mock_trigger]) + mock_db.find.return_value = mock_query + + mock_settings = MagicMock() + mock_settings.trigger_retention_hours = 48 + mock_get_settings.return_value = mock_settings + + await cancel_triggers(namespace_name, graph_name, x_exosphere_request_id) + + # Verify mark_as_cancelled was called with correct retention_hours + mock_mark_cancelled.assert_called_once_with(mock_trigger, 48) + + +@pytest.mark.asyncio +async def test_cancel_triggers_handles_database_error(): + """Test that database errors are properly logged and re-raised""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + with patch('app.controller.cancel_triggers.DatabaseTriggers') as mock_db: + mock_db.find.side_effect = Exception("Database connection error") + + with pytest.raises(Exception, match="Database connection error"): + await cancel_triggers(namespace_name, graph_name, x_exosphere_request_id) + + +@pytest.mark.asyncio +async def test_cancel_triggers_handles_mark_error(): + """Test that errors during marking are properly handled""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + mock_trigger = MagicMock(spec=DatabaseTriggers) + mock_trigger.id = "trigger_id_1" + + with patch('app.controller.cancel_triggers.DatabaseTriggers') as mock_db, \ + patch('app.controller.cancel_triggers.get_settings') as mock_get_settings, \ + patch('app.controller.cancel_triggers.mark_as_cancelled') as mock_mark_cancelled: + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[mock_trigger]) + mock_db.find.return_value = mock_query + + mock_settings = MagicMock() + mock_settings.trigger_retention_hours = 24 + mock_get_settings.return_value = mock_settings + + mock_mark_cancelled.side_effect = Exception("Failed to update trigger") + + with pytest.raises(Exception, match="Failed to update trigger"): + await cancel_triggers(namespace_name, graph_name, x_exosphere_request_id) + + +@pytest.mark.asyncio +async def test_cancel_triggers_multiple_triggers_batch(): + """Test that multiple triggers are cancelled in batch""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + mock_trigger1 = MagicMock(spec=DatabaseTriggers) + mock_trigger1.id = "trigger_id_1" + mock_trigger2 = MagicMock(spec=DatabaseTriggers) + mock_trigger2.id = "trigger_id_2" + mock_trigger3 = MagicMock(spec=DatabaseTriggers) + mock_trigger3.id = "trigger_id_3" + + with patch('app.controller.cancel_triggers.DatabaseTriggers') as mock_db, \ + patch('app.controller.cancel_triggers.get_settings') as mock_get_settings, \ + patch('app.controller.cancel_triggers.mark_as_cancelled') as mock_mark_cancelled: + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[mock_trigger1, mock_trigger2, mock_trigger3]) + mock_db.find.return_value = mock_query + + mock_settings = MagicMock() + mock_settings.trigger_retention_hours = 24 + mock_get_settings.return_value = mock_settings + + result = await cancel_triggers(namespace_name, graph_name, x_exosphere_request_id) + + # Verify correct count + assert result.cancelled_count == 3 + + # Verify all triggers were processed + assert mock_mark_cancelled.call_count == 3 + +@pytest.mark.asyncio +async def test_cancel_triggers_calls_get_settings(): + """Test that get_settings is called when cancelling triggers""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + mock_trigger = MagicMock(spec=DatabaseTriggers) + mock_trigger.id = "trigger_id_1" + + with patch('app.controller.cancel_triggers.DatabaseTriggers') as mock_db, \ + patch('app.controller.cancel_triggers.get_settings') as mock_get_settings, \ + patch('app.controller.cancel_triggers.mark_as_cancelled') as mock_mark_cancelled: + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[mock_trigger]) + mock_db.find.return_value = mock_query + + mock_settings = MagicMock() + mock_settings.trigger_retention_hours = 24 + mock_get_settings.return_value = mock_settings + + await cancel_triggers(namespace_name, graph_name, x_exosphere_request_id) + + # Verify get_settings was called (only when there are triggers to cancel) + mock_get_settings.assert_called_once() + # Verify it was called with retention_hours from settings + mock_mark_cancelled.assert_called_once_with(mock_trigger, 24) diff --git a/state-manager/tests/unit/tasks/test_trigger_cron.py b/state-manager/tests/unit/tasks/test_trigger_cron.py index 7179e08a..95f0aa1c 100644 --- a/state-manager/tests/unit/tasks/test_trigger_cron.py +++ b/state-manager/tests/unit/tasks/test_trigger_cron.py @@ -10,6 +10,7 @@ from app.tasks.trigger_cron import ( mark_as_triggered, mark_as_failed, + mark_as_cancelled, get_due_triggers, call_trigger_graph, create_next_triggers, @@ -24,6 +25,7 @@ @pytest.mark.parametrize("mark_function,expected_status", [ (mark_as_triggered, TriggerStatusEnum.TRIGGERED), (mark_as_failed, TriggerStatusEnum.FAILED), + (mark_as_cancelled, TriggerStatusEnum.CANCELLED), ]) async def test_mark_trigger_sets_expires_at(mark_function, expected_status): """Test that marking a trigger sets the expires_at field correctly""" @@ -69,6 +71,9 @@ async def test_mark_trigger_sets_expires_at(mark_function, expected_status): (mark_as_failed, 12), (mark_as_failed, 24), (mark_as_failed, 48), + (mark_as_cancelled, 12), + (mark_as_cancelled, 24), + (mark_as_cancelled, 48), ]) async def test_mark_trigger_uses_custom_retention_period(mark_function, retention_hours): """Test that custom retention period is respected across all mark functions""" diff --git a/state-manager/tests/unit/test_routes.py b/state-manager/tests/unit/test_routes.py index 2477c18d..8495ab1f 100644 --- a/state-manager/tests/unit/test_routes.py +++ b/state-manager/tests/unit/test_routes.py @@ -9,6 +9,7 @@ from app.models.list_models import ListRegisteredNodesResponse, ListGraphTemplatesResponse from app.models.run_models import RunsResponse, RunListItem, RunStatusEnum from app.models.manual_retry import ManualRetryRequestModel, ManualRetryResponseModel +from app.models.cancel_trigger_models import CancelTriggerResponse import pytest @@ -37,6 +38,7 @@ def test_router_has_correct_routes(self): # Graph template routes (there are two /graph/{graph_name} routes - GET and PUT) assert any('/v0/namespace/{namespace_name}/graph/{graph_name}' in path for path in paths) + assert any('/v0/namespace/{namespace_name}/graph/{graph_name}/triggers' in path for path in paths) # Node registration routes assert any('/v0/namespace/{namespace_name}/nodes/' in path for path in paths) @@ -295,6 +297,36 @@ def test_manual_retry_response_model_validation(self): assert model.id == "507f1f77bcf86cd799439011" assert model.status == StateStatusEnum.CREATED + def test_cancel_trigger_response_model_validation(self): + """Test CancelTriggerResponse model validation""" + # Test with valid data + valid_data = { + "namespace": "test_namespace", + "graph_name": "test_graph", + "cancelled_count": 5, + "message": "Successfully cancelled 5 trigger(s)" + } + model = CancelTriggerResponse(**valid_data) + assert model.namespace == "test_namespace" + assert model.graph_name == "test_graph" + assert model.cancelled_count == 5 + assert model.message == "Successfully cancelled 5 trigger(s)" + + def test_cancel_trigger_response_model_with_zero_count(self): + """Test CancelTriggerResponse model with zero count""" + # Test with cancelled_count=0 (no triggers scenario) + valid_data = { + "namespace": "test_namespace", + "graph_name": "test_graph", + "cancelled_count": 0, + "message": "No pending triggers found to cancel" + } + model = CancelTriggerResponse(**valid_data) + assert model.namespace == "test_namespace" + assert model.graph_name == "test_graph" + assert model.cancelled_count == 0 + assert model.message == "No pending triggers found to cancel" + @@ -318,7 +350,8 @@ def test_route_handlers_exist(self): get_runs_route, get_graph_structure_route, get_node_run_details_route, - manual_retry_state_route + manual_retry_state_route, + cancel_triggers_route ) @@ -337,6 +370,7 @@ def test_route_handlers_exist(self): assert callable(get_graph_structure_route) assert callable(get_node_run_details_route) assert callable(manual_retry_state_route) + assert callable(cancel_triggers_route) @@ -1117,4 +1151,71 @@ async def test_manual_retry_state_route_without_request_id(self, mock_manual_ret assert call_args[0][2] == body # body # Should generate a UUID when no request ID is present assert len(call_args[0][3]) > 0 # x_exosphere_request_id should be generated - assert result == mock_manual_retry_state.return_value \ No newline at end of file + assert result == mock_manual_retry_state.return_value + + @patch('app.routes.cancel_triggers') + async def test_cancel_triggers_route_with_valid_api_key(self, mock_cancel_triggers, mock_request): + """Test cancel_triggers_route with valid API key""" + from app.routes import cancel_triggers_route + + # Arrange + expected_response = CancelTriggerResponse( + namespace="test_namespace", + graph_name="test_graph", + cancelled_count=3, + message="Successfully cancelled 3 trigger(s)" + ) + mock_cancel_triggers.return_value = expected_response + + # Act + result = await cancel_triggers_route("test_namespace", "test_graph", mock_request, "valid_key") + + # Assert + mock_cancel_triggers.assert_called_once_with("test_namespace", "test_graph", "test-request-id") + assert result == expected_response + assert result.namespace == "test_namespace" + assert result.graph_name == "test_graph" + assert result.cancelled_count == 3 + + @patch('app.routes.cancel_triggers') + async def test_cancel_triggers_route_with_invalid_api_key(self, mock_cancel_triggers, mock_request): + """Test cancel_triggers_route with invalid API key""" + from app.routes import cancel_triggers_route + from fastapi import HTTPException, status + + # Arrange + mock_cancel_triggers.return_value = MagicMock() + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await cancel_triggers_route("test_namespace", "test_graph", mock_request, None) # type: ignore + + assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED + assert exc_info.value.detail == "Invalid API key" + mock_cancel_triggers.assert_not_called() + + @patch('app.routes.cancel_triggers') + async def test_cancel_triggers_route_without_request_id(self, mock_cancel_triggers, mock_request_no_id): + """Test cancel_triggers_route without x_exosphere_request_id""" + from app.routes import cancel_triggers_route + + # Arrange + expected_response = CancelTriggerResponse( + namespace="test_namespace", + graph_name="test_graph", + cancelled_count=2, + message="Successfully cancelled 2 trigger(s)" + ) + mock_cancel_triggers.return_value = expected_response + + # Act + result = await cancel_triggers_route("test_namespace", "test_graph", mock_request_no_id, "valid_key") + + # Assert + mock_cancel_triggers.assert_called_once() + call_args = mock_cancel_triggers.call_args + assert call_args[0][0] == "test_namespace" + assert call_args[0][1] == "test_graph" + # Should generate a UUID when no request ID is present + assert len(call_args[0][2]) > 0 # x_exosphere_request_id should be generated + assert result == expected_response \ No newline at end of file