diff --git a/state-manager/app/controller/errored_state.py b/state-manager/app/controller/errored_state.py index e8eb5331..14097616 100644 --- a/state-manager/app/controller/errored_state.py +++ b/state-manager/app/controller/errored_state.py @@ -9,11 +9,14 @@ from app.models.state_status_enum import StateStatusEnum from app.singletons.logs_manager import LogsManager from app.models.db.graph_template_model import GraphTemplate +from app.tasks.webhook import dispatch_webhook +from datetime import datetime +from fastapi import BackgroundTasks logger = LogsManager().get_logger() - -async def errored_state(namespace_name: str, state_id: PydanticObjectId, body: ErroredRequestModel, x_exosphere_request_id: str) -> ErroredResponseModel: - +async def errored_state(namespace_name: str, state_id: PydanticObjectId, body: ErroredRequestModel, x_exosphere_request_id: str, background_tasks: BackgroundTasks | None = None,) -> ErroredResponseModel: + if background_tasks is None: + background_tasks = BackgroundTasks() try: logger.info(f"Errored state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) @@ -70,6 +73,26 @@ async def errored_state(namespace_name: str, state_id: PydanticObjectId, body: E state.error = body.error await state.save() + if ( + not retry_created + and graph_template.webhook + and "GRAPH_FAILED" in graph_template.webhook.events + ): + background_tasks.add_task( + dispatch_webhook, + url=graph_template.webhook.url, + payload={ + "event": "GRAPH_FAILED", + "namespace": namespace_name, + "graph_name": state.graph_name, + "run_id": state.run_id, + "failed_state_id": str(state.id), + "node_name": state.node_name, + "error": body.error, + "timestamp": datetime.utcnow().isoformat(), + }, + headers=graph_template.webhook.headers, + ) return ErroredResponseModel(status=StateStatusEnum.ERRORED, retry_created=retry_created) except Exception as e: diff --git a/state-manager/app/controller/executed_state.py b/state-manager/app/controller/executed_state.py index 27baabd7..bda3da81 100644 --- a/state-manager/app/controller/executed_state.py +++ b/state-manager/app/controller/executed_state.py @@ -1,8 +1,7 @@ from beanie import PydanticObjectId -from app.models.executed_models import ExecutedRequestModel, ExecutedResponseModel - from fastapi import HTTPException, status, BackgroundTasks +from app.models.executed_models import ExecutedRequestModel, ExecutedResponseModel from app.models.db.state import State from app.models.state_status_enum import StateStatusEnum from app.singletons.logs_manager import LogsManager @@ -10,19 +9,36 @@ logger = LogsManager().get_logger() -async def executed_state(namespace_name: str, state_id: PydanticObjectId, body: ExecutedRequestModel, x_exosphere_request_id: str, background_tasks: BackgroundTasks) -> ExecutedResponseModel: +async def executed_state( + namespace_name: str, + state_id: PydanticObjectId, + body: ExecutedRequestModel, + x_exosphere_request_id: str, + background_tasks: BackgroundTasks, +) -> ExecutedResponseModel: try: - logger.info(f"Executed state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + logger.info( + f"Executed state {state_id} for namespace {namespace_name}", + x_exosphere_request_id=x_exosphere_request_id, + ) state = await State.find_one(State.id == state_id) if not state or not state.id: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="State not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="State not found", + ) if state.status != StateStatusEnum.QUEUED: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="State is not queued") - - next_state_ids = [] + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="State is not queued", + ) + + next_state_ids: list[PydanticObjectId] = [] + + # ---- Handle outputs ---- if len(body.outputs) == 0: state.status = StateStatusEnum.EXECUTED state.outputs = {} @@ -30,35 +46,56 @@ async def executed_state(namespace_name: str, state_id: PydanticObjectId, body: next_state_ids.append(state.id) - else: + else: + # First output updates the current state state.outputs = body.outputs[0] state.status = StateStatusEnum.EXECUTED await state.save() + next_state_ids.append(state.id) + # Remaining outputs create new states new_states = [] for output in body.outputs[1:]: - new_states.append(State( - node_name=state.node_name, - namespace_name=state.namespace_name, - identifier=state.identifier, - graph_name=state.graph_name, - run_id=state.run_id, - status=StateStatusEnum.EXECUTED, - inputs=state.inputs, - outputs=output, - error=None, - parents=state.parents - )) - - if len(new_states) > 0: - inserted_ids = (await State.insert_many(new_states)).inserted_ids + new_states.append( + State( + node_name=state.node_name, + namespace_name=state.namespace_name, + identifier=state.identifier, + graph_name=state.graph_name, + run_id=state.run_id, + status=StateStatusEnum.EXECUTED, + inputs=state.inputs, + outputs=output, + error=None, + parents=state.parents, + ) + ) + + if new_states: + inserted_ids = ( + await State.insert_many(new_states) + ).inserted_ids next_state_ids.extend(inserted_ids) - background_tasks.add_task(create_next_states, next_state_ids, state.identifier, state.namespace_name, state.graph_name, state.parents) + # ---- Create next states ---- + background_tasks.add_task( + create_next_states, + next_state_ids, + state.identifier, + state.namespace_name, + state.graph_name, + state.parents, + ) - return ExecutedResponseModel(status=StateStatusEnum.EXECUTED) + return ExecutedResponseModel( + status=StateStatusEnum.EXECUTED + ) except Exception as e: - logger.error(f"Error executing state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id, error=e) - raise e + logger.error( + f"Error executing state {state_id} for namespace {namespace_name}", + x_exosphere_request_id=x_exosphere_request_id, + error=e, + ) + raise diff --git a/state-manager/app/models/db/graph_template_model.py b/state-manager/app/models/db/graph_template_model.py index 9693981d..2e8554f3 100644 --- a/state-manager/app/models/db/graph_template_model.py +++ b/state-manager/app/models/db/graph_template_model.py @@ -14,6 +14,7 @@ from app.models.retry_policy_model import RetryPolicyModel from app.models.store_config_model import StoreConfig from app.models.trigger_models import Trigger +from app.models.webhook_config_model import WebhookConfig class GraphTemplate(BaseDatabaseModel): name: str = Field(..., description="Name of the graph") @@ -25,6 +26,7 @@ class GraphTemplate(BaseDatabaseModel): triggers: List[Trigger] = Field(default_factory=list, description="Triggers of the graph") retry_policy: RetryPolicyModel = Field(default_factory=RetryPolicyModel, description="Retry policy of the graph") store_config: StoreConfig = Field(default_factory=StoreConfig, description="Store config of the graph") + webhook: WebhookConfig | None = Field(default=None, description="Optional webhook configuration for graph execution events") _node_by_identifier: Dict[str, NodeTemplate] | None = PrivateAttr(default=None) _parents_by_identifier: Dict[str, set[str]] | None = PrivateAttr(default=None) # type: ignore @@ -318,7 +320,7 @@ def get_path_by_identifier(self, identifier: str) -> set[str]: @staticmethod async def get(namespace: str, graph_name: str) -> "GraphTemplate": - graph_template = await GraphTemplate.find_one(GraphTemplate.namespace == namespace, GraphTemplate.name == graph_name) + graph_template = await GraphTemplate.find_one(GraphTemplate.namespace == namespace,GraphTemplate.name == graph_name) if not graph_template: raise ValueError(f"Graph template not found for namespace: {namespace} and graph name: {graph_name}") return graph_template diff --git a/state-manager/app/models/webhook_config_model.py b/state-manager/app/models/webhook_config_model.py new file mode 100644 index 00000000..69e0ddbf --- /dev/null +++ b/state-manager/app/models/webhook_config_model.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel, Field +from typing import List, Dict, Optional + + +class WebhookConfig(BaseModel): + url: str = Field(..., description="Webhook endpoint URL") + events: List[str] = Field(default_factory=list, description="Subscribed events") + headers: Optional[Dict[str, str]] = Field( + default=None, + description="Optional HTTP headers for webhook requests" + ) diff --git a/state-manager/app/tasks/webhook.py b/state-manager/app/tasks/webhook.py new file mode 100644 index 00000000..502393df --- /dev/null +++ b/state-manager/app/tasks/webhook.py @@ -0,0 +1,32 @@ +import logging +from datetime import datetime +from typing import Optional + +import httpx + +logger = logging.getLogger(__name__) + + +async def dispatch_webhook( + *, + url: str, + payload: dict, + headers: Optional[dict] = None, +) -> None: + """ + Dispatch a webhook event. + This must never raise exceptions (best-effort delivery). + """ + try: + async with httpx.AsyncClient(timeout=5) as client: + await client.post( + url, + json=payload, + headers=headers or {}, + ) + except Exception as exc: + logger.warning( + "Webhook dispatch failed", + exc_info=exc, + extra={"url": url}, + )