diff --git a/state-manager/app/config/settings.py b/state-manager/app/config/settings.py index 75cea36d..f5f4016f 100644 --- a/state-manager/app/config/settings.py +++ b/state-manager/app/config/settings.py @@ -13,6 +13,7 @@ class Settings(BaseModel): state_manager_secret: str = Field(..., description="Secret key for API authentication") secrets_encryption_key: str = Field(..., description="Key for encrypting secrets") trigger_workers: int = Field(default=1, description="Number of workers to run the trigger cron") + node_timeout_minutes: int = Field(default=60, gt=0, description="Timeout in minutes for nodes stuck in QUEUED status") trigger_retention_hours: int = Field(default=720, description="Number of hours to retain completed/failed triggers before cleanup") @classmethod @@ -22,6 +23,7 @@ def from_env(cls) -> "Settings": mongo_database_name=os.getenv("MONGO_DATABASE_NAME", "exosphere-state-manager"), # type: ignore state_manager_secret=os.getenv("STATE_MANAGER_SECRET"), # type: ignore secrets_encryption_key=os.getenv("SECRETS_ENCRYPTION_KEY"), # type: ignore + node_timeout_minutes=int(os.getenv("NODE_TIMEOUT_MINUTES", 60)), # type: ignore trigger_workers=int(os.getenv("TRIGGER_WORKERS", 1)), # type: ignore trigger_retention_hours=int(os.getenv("TRIGGER_RETENTION_HOURS", 720)) # type: ignore ) diff --git a/state-manager/app/controller/enqueue_states.py b/state-manager/app/controller/enqueue_states.py index a5c36b52..e9293c06 100644 --- a/state-manager/app/controller/enqueue_states.py +++ b/state-manager/app/controller/enqueue_states.py @@ -7,24 +7,47 @@ from ..models.state_status_enum import StateStatusEnum from app.singletons.logs_manager import LogsManager +from app.config.settings import get_settings from pymongo import ReturnDocument logger = LogsManager().get_logger() async def find_state(namespace_name: str, nodes: list[str]) -> State | None: + current_time_ms = int(time.time() * 1000) + settings = get_settings() + data = await State.get_pymongo_collection().find_one_and_update( { "namespace_name": namespace_name, "status": StateStatusEnum.CREATED, - "node_name": { - "$in": nodes - }, - "enqueue_after": {"$lte": int(time.time() * 1000)} - }, - { - "$set": {"status": StateStatusEnum.QUEUED} + "node_name": {"$in": nodes}, + "enqueue_after": {"$lte": current_time_ms} }, + [ + { + "$set": { + "status": StateStatusEnum.QUEUED, + "queued_at": current_time_ms, + "timeout_at": { + "$add": [ + current_time_ms, + { + "$multiply": [ + { + "$ifNull": [ + "$timeout_minutes", + settings.node_timeout_minutes + ] + }, + 60000 # Convert minutes to milliseconds + ] + } + ] + } + } + } + ], return_document=ReturnDocument.AFTER ) return State(**data) if data else None diff --git a/state-manager/app/controller/errored_state.py b/state-manager/app/controller/errored_state.py index e8eb5331..2be7dd76 100644 --- a/state-manager/app/controller/errored_state.py +++ b/state-manager/app/controller/errored_state.py @@ -53,7 +53,8 @@ async def errored_state(namespace_name: str, state_id: PydanticObjectId, body: E does_unites=state.does_unites, enqueue_after= int(time.time() * 1000) + graph_template.retry_policy.compute_delay(state.retry_count + 1), retry_count=state.retry_count + 1, - fanout_id=state.fanout_id + fanout_id=state.fanout_id, + timeout_at=state.timeout_at ) retry_state = await retry_state.insert() logger.info(f"Retry state {retry_state.id} created for state {state_id}", x_exosphere_request_id=x_exosphere_request_id) diff --git a/state-manager/app/controller/manual_retry_state.py b/state-manager/app/controller/manual_retry_state.py index b2ac3f47..b24de2d5 100644 --- a/state-manager/app/controller/manual_retry_state.py +++ b/state-manager/app/controller/manual_retry_state.py @@ -31,7 +31,8 @@ async def manual_retry_state(namespace_name: str, state_id: PydanticObjectId, bo parents=state.parents, does_unites=state.does_unites, fanout_id=body.fanout_id, # this will ensure that multiple unwanted retries are not formed because of index in database - manual_retry_fanout_id=body.fanout_id # This is included in the state fingerprint to allow unique manual retries of unite nodes. + manual_retry_fanout_id=body.fanout_id, # This is included in the state fingerprint to allow unique manual retries of unite nodes. + timeout_minutes=state.timeout_minutes ) retry_state = await retry_state.insert() logger.info(f"Retry state {retry_state.id} created for state {state_id}", x_exosphere_request_id=x_exosphere_request_id) diff --git a/state-manager/app/controller/re_queue_after_signal.py b/state-manager/app/controller/re_queue_after_signal.py index 009f1424..682be841 100644 --- a/state-manager/app/controller/re_queue_after_signal.py +++ b/state-manager/app/controller/re_queue_after_signal.py @@ -21,6 +21,8 @@ async def re_queue_after_signal(namespace_name: str, state_id: PydanticObjectId, state.status = StateStatusEnum.CREATED state.enqueue_after = int(time.time() * 1000) + body.enqueue_after + state.timeout_at = None + state.queued_at = None await state.save() return SignalResponseModel(status=StateStatusEnum.CREATED, enqueue_after=state.enqueue_after) diff --git a/state-manager/app/controller/register_nodes.py b/state-manager/app/controller/register_nodes.py index 2b820ec9..8fdd9b24 100644 --- a/state-manager/app/controller/register_nodes.py +++ b/state-manager/app/controller/register_nodes.py @@ -31,7 +31,8 @@ async def register_nodes(namespace_name: str, body: RegisterNodesRequestModel, x RegisteredNode.runtime_namespace: namespace_name, RegisteredNode.inputs_schema: node_data.inputs_schema, # type: ignore RegisteredNode.outputs_schema: node_data.outputs_schema, # type: ignore - RegisteredNode.secrets: node_data.secrets # type: ignore + RegisteredNode.secrets: node_data.secrets, # type: ignore + RegisteredNode.timeout_minutes: node_data.timeout_minutes # type: ignore })) logger.info(f"Updated existing node {node_data.name} in namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) @@ -44,7 +45,8 @@ async def register_nodes(namespace_name: str, body: RegisterNodesRequestModel, x runtime_namespace=namespace_name, inputs_schema=node_data.inputs_schema, outputs_schema=node_data.outputs_schema, - secrets=node_data.secrets + secrets=node_data.secrets, + timeout_minutes=node_data.timeout_minutes ) await new_node.insert() logger.info(f"Created new node {node_data.name} in namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) @@ -54,7 +56,8 @@ async def register_nodes(namespace_name: str, body: RegisterNodesRequestModel, x name=node_data.name, inputs_schema=node_data.inputs_schema, outputs_schema=node_data.outputs_schema, - secrets=node_data.secrets + secrets=node_data.secrets, + timeout_minutes=node_data.timeout_minutes ) ) diff --git a/state-manager/app/controller/trigger_graph.py b/state-manager/app/controller/trigger_graph.py index 46613e12..65cb4eba 100644 --- a/state-manager/app/controller/trigger_graph.py +++ b/state-manager/app/controller/trigger_graph.py @@ -7,8 +7,10 @@ from app.models.db.store import Store from app.models.db.run import Run from app.models.db.graph_template_model import GraphTemplate +from app.models.db.registered_node import RegisteredNode from app.models.node_template_model import NodeTemplate from app.models.dependent_string import DependentString +from app.config.settings import get_settings import uuid import time @@ -91,6 +93,16 @@ async def trigger_graph(namespace_name: str, graph_name: str, body: TriggerGraph if len(new_stores) > 0: await Store.insert_many(new_stores) + # Get node timeout setting + registered_node = await RegisteredNode.get_by_name_and_namespace(root.node_name, root.namespace) + timeout_minutes = None + if registered_node and registered_node.timeout_minutes: + timeout_minutes = registered_node.timeout_minutes + else: + # Fall back to global setting + settings = get_settings() + timeout_minutes = settings.node_timeout_minutes + new_state = State( node_name=root.node_name, namespace_name=namespace_name, @@ -101,7 +113,8 @@ async def trigger_graph(namespace_name: str, graph_name: str, body: TriggerGraph enqueue_after=int(time.time() * 1000) + body.start_delay, inputs=inputs, outputs={}, - error=None + error=None, + timeout_minutes=timeout_minutes ) await new_state.insert() diff --git a/state-manager/app/main.py b/state-manager/app/main.py index 0486a0c4..2004cdd7 100644 --- a/state-manager/app/main.py +++ b/state-manager/app/main.py @@ -38,6 +38,7 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.triggers.cron import CronTrigger from .tasks.trigger_cron import trigger_cron +from .tasks.check_node_timeout import check_node_timeout # init tasks from .tasks.init_tasks import init_tasks @@ -83,6 +84,15 @@ async def lifespan(app: FastAPI): max_instances=1, id="every_minute_task" ) + scheduler.add_job( + check_node_timeout, + CronTrigger.from_crontab("* * * * *"), + replace_existing=True, + misfire_grace_time=60, + coalesce=True, + max_instances=1, + id="check_node_timeout_task" + ) scheduler.start() # main logic of the server diff --git a/state-manager/app/models/db/registered_node.py b/state-manager/app/models/db/registered_node.py index 9bc7c214..27fc97a2 100644 --- a/state-manager/app/models/db/registered_node.py +++ b/state-manager/app/models/db/registered_node.py @@ -1,6 +1,6 @@ from .base import BaseDatabaseModel from pydantic import Field -from typing import Any +from typing import Any, Optional from pymongo import IndexModel from ..node_template_model import NodeTemplate @@ -13,6 +13,7 @@ class RegisteredNode(BaseDatabaseModel): inputs_schema: dict[str, Any] = Field(..., description="JSON schema for node inputs") outputs_schema: dict[str, Any] = Field(..., description="JSON schema for node outputs") secrets: list[str] = Field(default_factory=list, description="List of secrets that the node uses") + timeout_minutes: Optional[int] = Field(None, gt=0, description="Timeout in minutes for this node. Falls back to global setting if not provided") class Settings: indexes = [ diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index 6b9a8c74..942b258f 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -28,6 +28,9 @@ class State(BaseDatabaseModel): retry_count: int = Field(default=0, description="Number of times the state has been retried") fanout_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Fanout ID of the state") manual_retry_fanout_id: str = Field(default="", description="Fanout ID from a manual retry request, ensuring unique retries for unite nodes.") + queued_at: Optional[int] = Field(default=None, description="Unix time in milliseconds when state was queued") + timeout_at: Optional[int] = Field(default=None, description="Unix time in milliseconds when state times out") + timeout_minutes: Optional[int] = Field(default=None, gt=0, description="Timeout in minutes for this specific state, taken from node registration") @before_event([Insert, Replace, Save]) def _generate_fingerprint(self): @@ -102,5 +105,12 @@ class Settings: ("status", 1), ], name="run_id_status_index" + ), + IndexModel( + [ + ("status", 1), + ("timeout_at", 1), + ], + name="timeout_query_index" ) ] \ No newline at end of file diff --git a/state-manager/app/models/register_nodes_request.py b/state-manager/app/models/register_nodes_request.py index 38f83561..5723b20f 100644 --- a/state-manager/app/models/register_nodes_request.py +++ b/state-manager/app/models/register_nodes_request.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import Any, List +from typing import Any, List, Optional class NodeRegistrationModel(BaseModel): @@ -7,6 +7,7 @@ class NodeRegistrationModel(BaseModel): inputs_schema: dict[str, Any] = Field(..., description="JSON schema for node inputs") outputs_schema: dict[str, Any] = Field(..., description="JSON schema for node outputs") secrets: List[str] = Field(..., description="List of secrets that the node uses") + timeout_minutes: Optional[int] = Field(None, gt=0, description="Timeout in minutes for this node. Falls back to global setting if not provided") class RegisterNodesRequestModel(BaseModel): diff --git a/state-manager/app/models/register_nodes_response.py b/state-manager/app/models/register_nodes_response.py index 991832a8..696b4643 100644 --- a/state-manager/app/models/register_nodes_response.py +++ b/state-manager/app/models/register_nodes_response.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import Any, List +from typing import Any, List, Optional class RegisteredNodeModel(BaseModel): @@ -7,6 +7,7 @@ class RegisteredNodeModel(BaseModel): inputs_schema: dict[str, Any] = Field(..., description="Inputs for the registered node") outputs_schema: dict[str, Any] = Field(..., description="Outputs for the registered node") secrets: List[str] = Field(..., description="List of secrets that the node uses") + timeout_minutes: Optional[int] = Field(None, gt=0, description="Timeout in minutes for this node. Falls back to global setting if not provided") class RegisterNodesResponseModel(BaseModel): diff --git a/state-manager/app/models/state_status_enum.py b/state-manager/app/models/state_status_enum.py index cdbf563d..8eda7821 100644 --- a/state-manager/app/models/state_status_enum.py +++ b/state-manager/app/models/state_status_enum.py @@ -11,6 +11,7 @@ class StateStatusEnum(str, Enum): # Errored ERRORED = 'ERRORED' NEXT_CREATED_ERROR = 'NEXT_CREATED_ERROR' + TIMEDOUT = 'TIMEDOUT' # Success SUCCESS = 'SUCCESS' diff --git a/state-manager/app/tasks/check_node_timeout.py b/state-manager/app/tasks/check_node_timeout.py new file mode 100644 index 00000000..f4aa3216 --- /dev/null +++ b/state-manager/app/tasks/check_node_timeout.py @@ -0,0 +1,33 @@ +import time +from app.models.db.state import State +from app.models.state_status_enum import StateStatusEnum +from app.singletons.logs_manager import LogsManager + +logger = LogsManager().get_logger() + + +async def check_node_timeout(): + try: + current_time_ms = int(time.time() * 1000) + + logger.info(f"Checking for timed out nodes at {current_time_ms}") + + # Use database query to find and update timed out states in one operation + result = await State.get_pymongo_collection().update_many( + { + "status": StateStatusEnum.QUEUED, + "timeout_at": {"$ne": None, "$lte": current_time_ms} + }, + { + "$set": { + "status": StateStatusEnum.TIMEDOUT, + "error": "Node execution timed out" + } + } + ) + + if result.modified_count > 0: + logger.info(f"Marked {result.modified_count} states as TIMEDOUT") + + except Exception: + logger.error("Error checking node timeout", exc_info=True) diff --git a/state-manager/app/tasks/create_next_states.py b/state-manager/app/tasks/create_next_states.py index a5d86806..81541184 100644 --- a/state-manager/app/tasks/create_next_states.py +++ b/state-manager/app/tasks/create_next_states.py @@ -10,6 +10,7 @@ from app.models.db.store import Store from app.models.dependent_string import DependentString from app.models.node_template_model import UnitesStrategyEnum +from app.config.settings import get_settings from json_schema_to_pydantic import create_model from pydantic import BaseModel from typing import Type @@ -162,6 +163,16 @@ async def generate_next_state(next_state_input_model: Type[BaseModel], next_stat current_state.identifier: current_state.id } + # Get timeout for this node + registered_node = await get_registered_node(next_state_node_template) + timeout_minutes = None + if registered_node.timeout_minutes: + timeout_minutes = registered_node.timeout_minutes + else: + # Fall back to global setting + settings = get_settings() + timeout_minutes = settings.node_timeout_minutes + return State( node_name=next_state_node_template.node_name, identifier=next_state_node_template.identifier, @@ -173,7 +184,8 @@ async def generate_next_state(next_state_input_model: Type[BaseModel], next_stat outputs={}, does_unites=next_state_node_template.unites is not None, run_id=current_state.run_id, - error=None + error=None, + timeout_minutes=timeout_minutes ) current_states = await State.find( diff --git a/state-manager/tests/unit/tasks/test_check_node_timeout.py b/state-manager/tests/unit/tasks/test_check_node_timeout.py new file mode 100644 index 00000000..35ef81f7 --- /dev/null +++ b/state-manager/tests/unit/tasks/test_check_node_timeout.py @@ -0,0 +1,152 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from app.tasks.check_node_timeout import check_node_timeout +from app.models.state_status_enum import StateStatusEnum + + +@pytest.mark.asyncio +async def test_check_node_timeout_with_different_timeouts(): + """Test that nodes with different timeout_minutes are handled correctly""" + + # Mock current time (Unix timestamp in seconds: 1700000000 = Nov 14, 2023) + current_time_ms = 1700000000 * 1000 # Convert to milliseconds + + # Create mock states with different timeouts + state1 = MagicMock() + state1.status = StateStatusEnum.QUEUED + state1.queued_at = current_time_ms - (35 * 60 * 1000) # 35 minutes ago + state1.timeout_minutes = 30 # Should timeout (35 > 30) + + state2 = MagicMock() + state2.status = StateStatusEnum.QUEUED + state2.queued_at = current_time_ms - (25 * 60 * 1000) # 25 minutes ago + state2.timeout_minutes = 30 # Should NOT timeout (25 < 30) + + state3 = MagicMock() + state3.status = StateStatusEnum.QUEUED + state3.queued_at = current_time_ms - (45 * 60 * 1000) # 45 minutes ago + state3.timeout_minutes = None # Use global setting (30 min) + + mock_states = [state1, state2, state3] + + # Mock settings + mock_settings = MagicMock() + mock_settings.node_timeout_minutes = 30 # Global setting + + with patch('app.tasks.check_node_timeout.State') as mock_state_class, \ + patch('app.tasks.check_node_timeout.get_settings', return_value=mock_settings), \ + patch('app.tasks.check_node_timeout.time.time', return_value=current_time_ms / 1000): + + # Mock State.find().to_list() to return our mock states + mock_state_class.find.return_value.to_list = AsyncMock(return_value=mock_states) + mock_state_class.save_all = AsyncMock() + + await check_node_timeout() + + # Verify state1 and state3 were marked as TIMEDOUT, but not state2 + assert state1.status == StateStatusEnum.TIMEDOUT + assert state1.error == "Node execution timed out after 30 minutes" + + assert state2.status == StateStatusEnum.QUEUED # Should remain QUEUED + + assert state3.status == StateStatusEnum.TIMEDOUT + assert state3.error == "Node execution timed out after 30 minutes" + + # Verify save_all was called with the 2 timed out states + mock_state_class.save_all.assert_called_once() + saved_states = mock_state_class.save_all.call_args[0][0] + assert len(saved_states) == 2 + assert state1 in saved_states + assert state3 in saved_states + + +@pytest.mark.asyncio +async def test_check_node_timeout_no_timeouts(): + """Test that no states are marked as timed out when none exceed their timeout""" + + current_time_ms = 1700000000 * 1000 + + # Create mock state that hasn't timed out + state1 = MagicMock() + state1.status = StateStatusEnum.QUEUED + state1.queued_at = current_time_ms - (10 * 60 * 1000) # 10 minutes ago + state1.timeout_minutes = 30 # Should NOT timeout + + mock_states = [state1] + mock_settings = MagicMock() + mock_settings.node_timeout_minutes = 30 + + with patch('app.tasks.check_node_timeout.State') as mock_state_class, \ + patch('app.tasks.check_node_timeout.get_settings', return_value=mock_settings), \ + patch('app.tasks.check_node_timeout.time.time', return_value=current_time_ms / 1000): + + mock_state_class.find.return_value.to_list = AsyncMock(return_value=mock_states) + mock_state_class.save_all = AsyncMock() + + await check_node_timeout() + + # Verify state remains QUEUED + assert state1.status == StateStatusEnum.QUEUED + + # Verify save_all was not called since no states timed out + mock_state_class.save_all.assert_not_called() + + +@pytest.mark.asyncio +async def test_check_node_timeout_handles_exception(): + """Test that exceptions in check_node_timeout are logged properly""" + + with patch('app.tasks.check_node_timeout.State') as mock_state_class, \ + patch('app.tasks.check_node_timeout.logger') as mock_logger: + + # Mock State.find to raise an exception + mock_state_class.find.side_effect = Exception("Database error") + + await check_node_timeout() + + # Verify error was logged with exc_info + mock_logger.error.assert_called_once_with("Error checking node timeout", exc_info=True) + + +@pytest.mark.asyncio +async def test_check_node_timeout_custom_node_timeout(): + """Test that nodes with custom timeout_minutes use their own timeout value""" + + current_time_ms = 1700000000 * 1000 + + # Create mock state with custom timeout + state1 = MagicMock() + state1.status = StateStatusEnum.QUEUED + state1.queued_at = current_time_ms - (35 * 60 * 1000) # 35 minutes ago + state1.timeout_minutes = 60 # Custom timeout of 60 minutes - should NOT timeout + + # Create mock state with global timeout + state2 = MagicMock() + state2.status = StateStatusEnum.QUEUED + state2.queued_at = current_time_ms - (35 * 60 * 1000) # 35 minutes ago + state2.timeout_minutes = None # Use global setting (30 min) - should timeout + + mock_states = [state1, state2] + mock_settings = MagicMock() + mock_settings.node_timeout_minutes = 30 # Global setting + + with patch('app.tasks.check_node_timeout.State') as mock_state_class, \ + patch('app.tasks.check_node_timeout.get_settings', return_value=mock_settings), \ + patch('app.tasks.check_node_timeout.time.time', return_value=current_time_ms / 1000): + + mock_state_class.find.return_value.to_list = AsyncMock(return_value=mock_states) + mock_state_class.save_all = AsyncMock() + + await check_node_timeout() + + # Verify only state2 was marked as TIMEDOUT (using global 30 min timeout) + assert state1.status == StateStatusEnum.QUEUED # Custom 60 min timeout + + assert state2.status == StateStatusEnum.TIMEDOUT # Global 30 min timeout + assert state2.error == "Node execution timed out after 30 minutes" + + # Verify save_all was called with only state2 + mock_state_class.save_all.assert_called_once() + saved_states = mock_state_class.save_all.call_args[0][0] + assert len(saved_states) == 1 + assert state2 in saved_states \ No newline at end of file