From 735cd6dbd6e14ff300db1f536f02ea23d09c38a3 Mon Sep 17 00:00:00 2001 From: Agampreet Singh Date: Fri, 3 Oct 2025 11:12:55 +0530 Subject: [PATCH 1/7] feat: add node-level timeouts to prevent stuck queued states --- state-manager/app/config/settings.py | 4 +- .../app/controller/enqueue_states.py | 5 +- state-manager/app/main.py | 10 ++ state-manager/app/models/db/state.py | 8 ++ state-manager/app/models/state_status_enum.py | 1 + state-manager/app/tasks/check_node_timeout.py | 36 ++++++ .../unit/tasks/test_check_node_timeout.py | 114 ++++++++++++++++++ 7 files changed, 176 insertions(+), 2 deletions(-) create mode 100644 state-manager/app/tasks/check_node_timeout.py create mode 100644 state-manager/tests/unit/tasks/test_check_node_timeout.py diff --git a/state-manager/app/config/settings.py b/state-manager/app/config/settings.py index 5d75fc2b..3f173abc 100644 --- a/state-manager/app/config/settings.py +++ b/state-manager/app/config/settings.py @@ -13,6 +13,7 @@ class Settings(BaseModel): state_manager_secret: str = Field(..., description="Secret key for API authentication") secrets_encryption_key: str = Field(..., description="Key for encrypting secrets") trigger_workers: int = Field(default=1, description="Number of workers to run the trigger cron") + node_timeout_minutes: int = Field(default=30, description="Timeout in minutes for nodes stuck in QUEUED status") @classmethod def from_env(cls) -> "Settings": @@ -21,7 +22,8 @@ def from_env(cls) -> "Settings": mongo_database_name=os.getenv("MONGO_DATABASE_NAME", "exosphere-state-manager"), # type: ignore state_manager_secret=os.getenv("STATE_MANAGER_SECRET"), # type: ignore secrets_encryption_key=os.getenv("SECRETS_ENCRYPTION_KEY"), # type: ignore - trigger_workers=int(os.getenv("TRIGGER_WORKERS", 1)) # type: ignore + trigger_workers=int(os.getenv("TRIGGER_WORKERS", 1)), # type: ignore + node_timeout_minutes=int(os.getenv("NODE_TIMEOUT_MINUTES", 30)) # type: ignore ) diff --git a/state-manager/app/controller/enqueue_states.py b/state-manager/app/controller/enqueue_states.py index a5c36b52..8721cc46 100644 --- a/state-manager/app/controller/enqueue_states.py +++ b/state-manager/app/controller/enqueue_states.py @@ -23,7 +23,10 @@ async def find_state(namespace_name: str, nodes: list[str]) -> State | None: "enqueue_after": {"$lte": int(time.time() * 1000)} }, { - "$set": {"status": StateStatusEnum.QUEUED} + "$set": { + "status": StateStatusEnum.QUEUED, + "queued_at": int(time.time() * 1000) + } }, return_document=ReturnDocument.AFTER ) diff --git a/state-manager/app/main.py b/state-manager/app/main.py index b9e5d6ad..4ad2d979 100644 --- a/state-manager/app/main.py +++ b/state-manager/app/main.py @@ -38,6 +38,7 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.triggers.cron import CronTrigger from .tasks.trigger_cron import trigger_cron +from .tasks.check_node_timeout import check_node_timeout # Define models list DOCUMENT_MODELS = [State, GraphTemplate, RegisteredNode, Store, Run, DatabaseTriggers] @@ -76,6 +77,15 @@ async def lifespan(app: FastAPI): max_instances=1, id="every_minute_task" ) + scheduler.add_job( + check_node_timeout, + CronTrigger.from_crontab("* * * * *"), + replace_existing=True, + misfire_grace_time=60, + coalesce=True, + max_instances=1, + id="check_node_timeout_task" + ) scheduler.start() # main logic of the server diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index 6b9a8c74..4790b43f 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -28,6 +28,7 @@ class State(BaseDatabaseModel): retry_count: int = Field(default=0, description="Number of times the state has been retried") fanout_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Fanout ID of the state") manual_retry_fanout_id: str = Field(default="", description="Fanout ID from a manual retry request, ensuring unique retries for unite nodes.") + queued_at: Optional[int] = Field(None, description="Unix time in milliseconds when the state was queued") @before_event([Insert, Replace, Save]) def _generate_fingerprint(self): @@ -102,5 +103,12 @@ class Settings: ("status", 1), ], name="run_id_status_index" + ), + IndexModel( + [ + ("status", 1), + ("queued_at", 1), + ], + name="timeout_query_index" ) ] \ No newline at end of file diff --git a/state-manager/app/models/state_status_enum.py b/state-manager/app/models/state_status_enum.py index cdbf563d..8eda7821 100644 --- a/state-manager/app/models/state_status_enum.py +++ b/state-manager/app/models/state_status_enum.py @@ -11,6 +11,7 @@ class StateStatusEnum(str, Enum): # Errored ERRORED = 'ERRORED' NEXT_CREATED_ERROR = 'NEXT_CREATED_ERROR' + TIMEDOUT = 'TIMEDOUT' # Success SUCCESS = 'SUCCESS' diff --git a/state-manager/app/tasks/check_node_timeout.py b/state-manager/app/tasks/check_node_timeout.py new file mode 100644 index 00000000..db11f42f --- /dev/null +++ b/state-manager/app/tasks/check_node_timeout.py @@ -0,0 +1,36 @@ +import time +from app.models.db.state import State +from app.models.state_status_enum import StateStatusEnum +from app.singletons.logs_manager import LogsManager +from app.config.settings import get_settings + +logger = LogsManager().get_logger() + + +async def check_node_timeout(): + try: + settings = get_settings() + timeout_ms = settings.node_timeout_minutes * 60 * 1000 + current_time_ms = int(time.time() * 1000) + timeout_threshold = current_time_ms - timeout_ms + + logger.info(f"Checking for timed out nodes with threshold: {timeout_threshold}") + + result = await State.get_pymongo_collection().update_many( + { + "status": StateStatusEnum.QUEUED, + "queued_at": {"$ne": None, "$lte": timeout_threshold} + }, + { + "$set": { + "status": StateStatusEnum.TIMEDOUT, + "error": f"Node execution timed out after {settings.node_timeout_minutes} minutes" + } + } + ) + + if result.modified_count > 0: + logger.info(f"Marked {result.modified_count} states as TIMEDOUT") + + except Exception as e: + logger.error(f"Error checking node timeout: {e}") diff --git a/state-manager/tests/unit/tasks/test_check_node_timeout.py b/state-manager/tests/unit/tasks/test_check_node_timeout.py new file mode 100644 index 00000000..887cfd59 --- /dev/null +++ b/state-manager/tests/unit/tasks/test_check_node_timeout.py @@ -0,0 +1,114 @@ +import pytest +import time +from unittest.mock import AsyncMock, MagicMock, patch +from app.models.state_status_enum import StateStatusEnum + + +class TestCheckNodeTimeout: + + @pytest.mark.asyncio + async def test_check_node_timeout_marks_timed_out_states(self): + mock_collection = MagicMock() + mock_result = MagicMock() + mock_result.modified_count = 3 + mock_collection.update_many = AsyncMock(return_value=mock_result) + + with patch('app.tasks.check_node_timeout.State') as mock_state, \ + patch('app.tasks.check_node_timeout.get_settings') as mock_get_settings: + + from app.tasks.check_node_timeout import check_node_timeout + + mock_settings = MagicMock() + mock_settings.node_timeout_minutes = 30 + mock_get_settings.return_value = mock_settings + + mock_state.get_pymongo_collection.return_value = mock_collection + + await check_node_timeout() + + mock_collection.update_many.assert_called_once() + call_args = mock_collection.update_many.call_args + + query = call_args[0][0] + update = call_args[0][1] + + assert query["status"] == StateStatusEnum.QUEUED + assert "$ne" in query["queued_at"] + assert "$lte" in query["queued_at"] + + assert update["$set"]["status"] == StateStatusEnum.TIMEDOUT + assert "timed out after 30 minutes" in update["$set"]["error"] + + @pytest.mark.asyncio + async def test_check_node_timeout_no_timed_out_states(self): + mock_collection = MagicMock() + mock_result = MagicMock() + mock_result.modified_count = 0 + mock_collection.update_many = AsyncMock(return_value=mock_result) + + with patch('app.tasks.check_node_timeout.State') as mock_state, \ + patch('app.tasks.check_node_timeout.get_settings') as mock_get_settings: + + from app.tasks.check_node_timeout import check_node_timeout + + mock_settings = MagicMock() + mock_settings.node_timeout_minutes = 30 + mock_get_settings.return_value = mock_settings + + mock_state.get_pymongo_collection.return_value = mock_collection + + await check_node_timeout() + + mock_collection.update_many.assert_called_once() + + @pytest.mark.asyncio + async def test_check_node_timeout_handles_exception(self): + mock_collection = MagicMock() + mock_collection.update_many = AsyncMock(side_effect=Exception("Database error")) + + with patch('app.tasks.check_node_timeout.State') as mock_state, \ + patch('app.tasks.check_node_timeout.get_settings') as mock_get_settings, \ + patch('app.tasks.check_node_timeout.logger') as mock_logger: + + from app.tasks.check_node_timeout import check_node_timeout + + mock_settings = MagicMock() + mock_settings.node_timeout_minutes = 30 + mock_get_settings.return_value = mock_settings + + mock_state.get_pymongo_collection.return_value = mock_collection + + await check_node_timeout() + + mock_logger.error.assert_called_once() + error_message = mock_logger.error.call_args[0][0] + assert "Error checking node timeout" in error_message + + @pytest.mark.asyncio + async def test_check_node_timeout_calculates_correct_threshold(self): + mock_collection = MagicMock() + mock_result = MagicMock() + mock_result.modified_count = 0 + mock_collection.update_many = AsyncMock(return_value=mock_result) + + with patch('app.tasks.check_node_timeout.State') as mock_state, \ + patch('app.tasks.check_node_timeout.get_settings') as mock_get_settings, \ + patch('app.tasks.check_node_timeout.time') as mock_time: + + from app.tasks.check_node_timeout import check_node_timeout + + mock_time.time.return_value = 1000 + + mock_settings = MagicMock() + mock_settings.node_timeout_minutes = 45 + mock_get_settings.return_value = mock_settings + + mock_state.get_pymongo_collection.return_value = mock_collection + + await check_node_timeout() + + call_args = mock_collection.update_many.call_args + query = call_args[0][0] + + expected_threshold = (1000 * 1000) - (45 * 60 * 1000) + assert query["queued_at"]["$lte"] == expected_threshold From 4303a3deff0e4f5888d46037902cb5750ec4d558 Mon Sep 17 00:00:00 2001 From: Agampreet Singh Date: Fri, 3 Oct 2025 11:27:33 +0530 Subject: [PATCH 2/7] improved & fixed issues on timeout feature --- state-manager/app/config/settings.py | 4 ++-- state-manager/app/controller/enqueue_states.py | 5 +++-- state-manager/app/tasks/check_node_timeout.py | 4 ++-- state-manager/tests/unit/tasks/test_check_node_timeout.py | 4 ++-- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/state-manager/app/config/settings.py b/state-manager/app/config/settings.py index 3f173abc..7252af69 100644 --- a/state-manager/app/config/settings.py +++ b/state-manager/app/config/settings.py @@ -13,7 +13,7 @@ class Settings(BaseModel): state_manager_secret: str = Field(..., description="Secret key for API authentication") secrets_encryption_key: str = Field(..., description="Key for encrypting secrets") trigger_workers: int = Field(default=1, description="Number of workers to run the trigger cron") - node_timeout_minutes: int = Field(default=30, description="Timeout in minutes for nodes stuck in QUEUED status") + node_timeout_minutes: int = Field(default=30, gt=0, description="Timeout in minutes for nodes stuck in QUEUED status") @classmethod def from_env(cls) -> "Settings": @@ -23,7 +23,7 @@ def from_env(cls) -> "Settings": state_manager_secret=os.getenv("STATE_MANAGER_SECRET"), # type: ignore secrets_encryption_key=os.getenv("SECRETS_ENCRYPTION_KEY"), # type: ignore trigger_workers=int(os.getenv("TRIGGER_WORKERS", 1)), # type: ignore - node_timeout_minutes=int(os.getenv("NODE_TIMEOUT_MINUTES", 30)) # type: ignore + node_timeout_minutes=os.getenv("NODE_TIMEOUT_MINUTES", "30") # type: ignore ) diff --git a/state-manager/app/controller/enqueue_states.py b/state-manager/app/controller/enqueue_states.py index 8721cc46..22689e49 100644 --- a/state-manager/app/controller/enqueue_states.py +++ b/state-manager/app/controller/enqueue_states.py @@ -13,6 +13,7 @@ async def find_state(namespace_name: str, nodes: list[str]) -> State | None: + current_time_ms = int(time.time() * 1000) data = await State.get_pymongo_collection().find_one_and_update( { "namespace_name": namespace_name, @@ -20,12 +21,12 @@ async def find_state(namespace_name: str, nodes: list[str]) -> State | None: "node_name": { "$in": nodes }, - "enqueue_after": {"$lte": int(time.time() * 1000)} + "enqueue_after": {"$lte": current_time_ms} }, { "$set": { "status": StateStatusEnum.QUEUED, - "queued_at": int(time.time() * 1000) + "queued_at": current_time_ms } }, return_document=ReturnDocument.AFTER diff --git a/state-manager/app/tasks/check_node_timeout.py b/state-manager/app/tasks/check_node_timeout.py index db11f42f..48d9b1f8 100644 --- a/state-manager/app/tasks/check_node_timeout.py +++ b/state-manager/app/tasks/check_node_timeout.py @@ -32,5 +32,5 @@ async def check_node_timeout(): if result.modified_count > 0: logger.info(f"Marked {result.modified_count} states as TIMEDOUT") - except Exception as e: - logger.error(f"Error checking node timeout: {e}") + except Exception: + logger.error("Error checking node timeout", exc_info=True) diff --git a/state-manager/tests/unit/tasks/test_check_node_timeout.py b/state-manager/tests/unit/tasks/test_check_node_timeout.py index 887cfd59..97dc2f5a 100644 --- a/state-manager/tests/unit/tasks/test_check_node_timeout.py +++ b/state-manager/tests/unit/tasks/test_check_node_timeout.py @@ -97,7 +97,7 @@ async def test_check_node_timeout_calculates_correct_threshold(self): from app.tasks.check_node_timeout import check_node_timeout - mock_time.time.return_value = 1000 + mock_time.time.return_value = 1700000000 mock_settings = MagicMock() mock_settings.node_timeout_minutes = 45 @@ -110,5 +110,5 @@ async def test_check_node_timeout_calculates_correct_threshold(self): call_args = mock_collection.update_many.call_args query = call_args[0][0] - expected_threshold = (1000 * 1000) - (45 * 60 * 1000) + expected_threshold = (1700000000 * 1000) - (45 * 60 * 1000) assert query["queued_at"]["$lte"] == expected_threshold From 8634281956317afe00ad2dbf11f8a26a7215e89d Mon Sep 17 00:00:00 2001 From: Agampreet Singh Date: Fri, 3 Oct 2025 18:36:07 +0530 Subject: [PATCH 3/7] feat: implement node-level timeouts for stuck queued states --- state-manager/app/config/settings.py | 2 +- .../app/controller/register_nodes.py | 9 +- state-manager/app/controller/trigger_graph.py | 15 +- .../app/models/db/registered_node.py | 3 +- state-manager/app/models/db/state.py | 2 + .../app/models/register_nodes_request.py | 3 +- .../app/models/register_nodes_response.py | 3 +- state-manager/app/tasks/check_node_timeout.py | 40 +-- state-manager/app/tasks/create_next_states.py | 14 +- .../unit/tasks/test_check_node_timeout.py | 258 ++++++++++-------- 10 files changed, 213 insertions(+), 136 deletions(-) diff --git a/state-manager/app/config/settings.py b/state-manager/app/config/settings.py index 7252af69..76347d21 100644 --- a/state-manager/app/config/settings.py +++ b/state-manager/app/config/settings.py @@ -22,7 +22,7 @@ def from_env(cls) -> "Settings": mongo_database_name=os.getenv("MONGO_DATABASE_NAME", "exosphere-state-manager"), # type: ignore state_manager_secret=os.getenv("STATE_MANAGER_SECRET"), # type: ignore secrets_encryption_key=os.getenv("SECRETS_ENCRYPTION_KEY"), # type: ignore - trigger_workers=int(os.getenv("TRIGGER_WORKERS", 1)), # type: ignore + trigger_workers=os.getenv("TRIGGER_WORKERS", "1"), # type: ignore node_timeout_minutes=os.getenv("NODE_TIMEOUT_MINUTES", "30") # type: ignore ) diff --git a/state-manager/app/controller/register_nodes.py b/state-manager/app/controller/register_nodes.py index 2b820ec9..8fdd9b24 100644 --- a/state-manager/app/controller/register_nodes.py +++ b/state-manager/app/controller/register_nodes.py @@ -31,7 +31,8 @@ async def register_nodes(namespace_name: str, body: RegisterNodesRequestModel, x RegisteredNode.runtime_namespace: namespace_name, RegisteredNode.inputs_schema: node_data.inputs_schema, # type: ignore RegisteredNode.outputs_schema: node_data.outputs_schema, # type: ignore - RegisteredNode.secrets: node_data.secrets # type: ignore + RegisteredNode.secrets: node_data.secrets, # type: ignore + RegisteredNode.timeout_minutes: node_data.timeout_minutes # type: ignore })) logger.info(f"Updated existing node {node_data.name} in namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) @@ -44,7 +45,8 @@ async def register_nodes(namespace_name: str, body: RegisterNodesRequestModel, x runtime_namespace=namespace_name, inputs_schema=node_data.inputs_schema, outputs_schema=node_data.outputs_schema, - secrets=node_data.secrets + secrets=node_data.secrets, + timeout_minutes=node_data.timeout_minutes ) await new_node.insert() logger.info(f"Created new node {node_data.name} in namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) @@ -54,7 +56,8 @@ async def register_nodes(namespace_name: str, body: RegisterNodesRequestModel, x name=node_data.name, inputs_schema=node_data.inputs_schema, outputs_schema=node_data.outputs_schema, - secrets=node_data.secrets + secrets=node_data.secrets, + timeout_minutes=node_data.timeout_minutes ) ) diff --git a/state-manager/app/controller/trigger_graph.py b/state-manager/app/controller/trigger_graph.py index 46613e12..65cb4eba 100644 --- a/state-manager/app/controller/trigger_graph.py +++ b/state-manager/app/controller/trigger_graph.py @@ -7,8 +7,10 @@ from app.models.db.store import Store from app.models.db.run import Run from app.models.db.graph_template_model import GraphTemplate +from app.models.db.registered_node import RegisteredNode from app.models.node_template_model import NodeTemplate from app.models.dependent_string import DependentString +from app.config.settings import get_settings import uuid import time @@ -91,6 +93,16 @@ async def trigger_graph(namespace_name: str, graph_name: str, body: TriggerGraph if len(new_stores) > 0: await Store.insert_many(new_stores) + # Get node timeout setting + registered_node = await RegisteredNode.get_by_name_and_namespace(root.node_name, root.namespace) + timeout_minutes = None + if registered_node and registered_node.timeout_minutes: + timeout_minutes = registered_node.timeout_minutes + else: + # Fall back to global setting + settings = get_settings() + timeout_minutes = settings.node_timeout_minutes + new_state = State( node_name=root.node_name, namespace_name=namespace_name, @@ -101,7 +113,8 @@ async def trigger_graph(namespace_name: str, graph_name: str, body: TriggerGraph enqueue_after=int(time.time() * 1000) + body.start_delay, inputs=inputs, outputs={}, - error=None + error=None, + timeout_minutes=timeout_minutes ) await new_state.insert() diff --git a/state-manager/app/models/db/registered_node.py b/state-manager/app/models/db/registered_node.py index 9bc7c214..27fc97a2 100644 --- a/state-manager/app/models/db/registered_node.py +++ b/state-manager/app/models/db/registered_node.py @@ -1,6 +1,6 @@ from .base import BaseDatabaseModel from pydantic import Field -from typing import Any +from typing import Any, Optional from pymongo import IndexModel from ..node_template_model import NodeTemplate @@ -13,6 +13,7 @@ class RegisteredNode(BaseDatabaseModel): inputs_schema: dict[str, Any] = Field(..., description="JSON schema for node inputs") outputs_schema: dict[str, Any] = Field(..., description="JSON schema for node outputs") secrets: list[str] = Field(default_factory=list, description="List of secrets that the node uses") + timeout_minutes: Optional[int] = Field(None, gt=0, description="Timeout in minutes for this node. Falls back to global setting if not provided") class Settings: indexes = [ diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index 4790b43f..1b3ece7c 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -28,6 +28,8 @@ class State(BaseDatabaseModel): retry_count: int = Field(default=0, description="Number of times the state has been retried") fanout_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Fanout ID of the state") manual_retry_fanout_id: str = Field(default="", description="Fanout ID from a manual retry request, ensuring unique retries for unite nodes.") + queued_at: Optional[int] = Field(None, description="Unix time in milliseconds when state was queued") + timeout_minutes: Optional[int] = Field(None, gt=0, description="Timeout in minutes for this specific state, taken from node registration") queued_at: Optional[int] = Field(None, description="Unix time in milliseconds when the state was queued") @before_event([Insert, Replace, Save]) diff --git a/state-manager/app/models/register_nodes_request.py b/state-manager/app/models/register_nodes_request.py index 38f83561..5723b20f 100644 --- a/state-manager/app/models/register_nodes_request.py +++ b/state-manager/app/models/register_nodes_request.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import Any, List +from typing import Any, List, Optional class NodeRegistrationModel(BaseModel): @@ -7,6 +7,7 @@ class NodeRegistrationModel(BaseModel): inputs_schema: dict[str, Any] = Field(..., description="JSON schema for node inputs") outputs_schema: dict[str, Any] = Field(..., description="JSON schema for node outputs") secrets: List[str] = Field(..., description="List of secrets that the node uses") + timeout_minutes: Optional[int] = Field(None, gt=0, description="Timeout in minutes for this node. Falls back to global setting if not provided") class RegisterNodesRequestModel(BaseModel): diff --git a/state-manager/app/models/register_nodes_response.py b/state-manager/app/models/register_nodes_response.py index 991832a8..696b4643 100644 --- a/state-manager/app/models/register_nodes_response.py +++ b/state-manager/app/models/register_nodes_response.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import Any, List +from typing import Any, List, Optional class RegisteredNodeModel(BaseModel): @@ -7,6 +7,7 @@ class RegisteredNodeModel(BaseModel): inputs_schema: dict[str, Any] = Field(..., description="Inputs for the registered node") outputs_schema: dict[str, Any] = Field(..., description="Outputs for the registered node") secrets: List[str] = Field(..., description="List of secrets that the node uses") + timeout_minutes: Optional[int] = Field(None, gt=0, description="Timeout in minutes for this node. Falls back to global setting if not provided") class RegisterNodesResponseModel(BaseModel): diff --git a/state-manager/app/tasks/check_node_timeout.py b/state-manager/app/tasks/check_node_timeout.py index 48d9b1f8..4ccdbaab 100644 --- a/state-manager/app/tasks/check_node_timeout.py +++ b/state-manager/app/tasks/check_node_timeout.py @@ -10,27 +10,33 @@ async def check_node_timeout(): try: settings = get_settings() - timeout_ms = settings.node_timeout_minutes * 60 * 1000 current_time_ms = int(time.time() * 1000) - timeout_threshold = current_time_ms - timeout_ms - logger.info(f"Checking for timed out nodes with threshold: {timeout_threshold}") + logger.info(f"Checking for timed out nodes at {current_time_ms}") - result = await State.get_pymongo_collection().update_many( - { - "status": StateStatusEnum.QUEUED, - "queued_at": {"$ne": None, "$lte": timeout_threshold} - }, - { - "$set": { - "status": StateStatusEnum.TIMEDOUT, - "error": f"Node execution timed out after {settings.node_timeout_minutes} minutes" - } - } - ) + # Find all QUEUED states with queued_at set + queued_states = await State.find( + State.status == StateStatusEnum.QUEUED, + State.queued_at != None + ).to_list() - if result.modified_count > 0: - logger.info(f"Marked {result.modified_count} states as TIMEDOUT") + states_to_timeout = [] + + for state in queued_states: + # Use state-specific timeout if available, otherwise fall back to global + timeout_minutes = state.timeout_minutes if state.timeout_minutes else settings.node_timeout_minutes + timeout_ms = timeout_minutes * 60 * 1000 + timeout_threshold = current_time_ms - timeout_ms + + if state.queued_at <= timeout_threshold: + state.status = StateStatusEnum.TIMEDOUT + state.error = f"Node execution timed out after {timeout_minutes} minutes" + states_to_timeout.append(state) + + if states_to_timeout: + # Update all timed out states in bulk + await State.save_all(states_to_timeout) + logger.info(f"Marked {len(states_to_timeout)} states as TIMEDOUT") except Exception: logger.error("Error checking node timeout", exc_info=True) diff --git a/state-manager/app/tasks/create_next_states.py b/state-manager/app/tasks/create_next_states.py index a5d86806..81541184 100644 --- a/state-manager/app/tasks/create_next_states.py +++ b/state-manager/app/tasks/create_next_states.py @@ -10,6 +10,7 @@ from app.models.db.store import Store from app.models.dependent_string import DependentString from app.models.node_template_model import UnitesStrategyEnum +from app.config.settings import get_settings from json_schema_to_pydantic import create_model from pydantic import BaseModel from typing import Type @@ -162,6 +163,16 @@ async def generate_next_state(next_state_input_model: Type[BaseModel], next_stat current_state.identifier: current_state.id } + # Get timeout for this node + registered_node = await get_registered_node(next_state_node_template) + timeout_minutes = None + if registered_node.timeout_minutes: + timeout_minutes = registered_node.timeout_minutes + else: + # Fall back to global setting + settings = get_settings() + timeout_minutes = settings.node_timeout_minutes + return State( node_name=next_state_node_template.node_name, identifier=next_state_node_template.identifier, @@ -173,7 +184,8 @@ async def generate_next_state(next_state_input_model: Type[BaseModel], next_stat outputs={}, does_unites=next_state_node_template.unites is not None, run_id=current_state.run_id, - error=None + error=None, + timeout_minutes=timeout_minutes ) current_states = await State.find( diff --git a/state-manager/tests/unit/tasks/test_check_node_timeout.py b/state-manager/tests/unit/tasks/test_check_node_timeout.py index 97dc2f5a..35ef81f7 100644 --- a/state-manager/tests/unit/tasks/test_check_node_timeout.py +++ b/state-manager/tests/unit/tasks/test_check_node_timeout.py @@ -1,114 +1,152 @@ import pytest -import time -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch, MagicMock +from app.tasks.check_node_timeout import check_node_timeout from app.models.state_status_enum import StateStatusEnum -class TestCheckNodeTimeout: - - @pytest.mark.asyncio - async def test_check_node_timeout_marks_timed_out_states(self): - mock_collection = MagicMock() - mock_result = MagicMock() - mock_result.modified_count = 3 - mock_collection.update_many = AsyncMock(return_value=mock_result) - - with patch('app.tasks.check_node_timeout.State') as mock_state, \ - patch('app.tasks.check_node_timeout.get_settings') as mock_get_settings: - - from app.tasks.check_node_timeout import check_node_timeout - - mock_settings = MagicMock() - mock_settings.node_timeout_minutes = 30 - mock_get_settings.return_value = mock_settings - - mock_state.get_pymongo_collection.return_value = mock_collection - - await check_node_timeout() - - mock_collection.update_many.assert_called_once() - call_args = mock_collection.update_many.call_args - - query = call_args[0][0] - update = call_args[0][1] - - assert query["status"] == StateStatusEnum.QUEUED - assert "$ne" in query["queued_at"] - assert "$lte" in query["queued_at"] - - assert update["$set"]["status"] == StateStatusEnum.TIMEDOUT - assert "timed out after 30 minutes" in update["$set"]["error"] - - @pytest.mark.asyncio - async def test_check_node_timeout_no_timed_out_states(self): - mock_collection = MagicMock() - mock_result = MagicMock() - mock_result.modified_count = 0 - mock_collection.update_many = AsyncMock(return_value=mock_result) - - with patch('app.tasks.check_node_timeout.State') as mock_state, \ - patch('app.tasks.check_node_timeout.get_settings') as mock_get_settings: - - from app.tasks.check_node_timeout import check_node_timeout - - mock_settings = MagicMock() - mock_settings.node_timeout_minutes = 30 - mock_get_settings.return_value = mock_settings - - mock_state.get_pymongo_collection.return_value = mock_collection - - await check_node_timeout() - - mock_collection.update_many.assert_called_once() - - @pytest.mark.asyncio - async def test_check_node_timeout_handles_exception(self): - mock_collection = MagicMock() - mock_collection.update_many = AsyncMock(side_effect=Exception("Database error")) - - with patch('app.tasks.check_node_timeout.State') as mock_state, \ - patch('app.tasks.check_node_timeout.get_settings') as mock_get_settings, \ - patch('app.tasks.check_node_timeout.logger') as mock_logger: - - from app.tasks.check_node_timeout import check_node_timeout - - mock_settings = MagicMock() - mock_settings.node_timeout_minutes = 30 - mock_get_settings.return_value = mock_settings - - mock_state.get_pymongo_collection.return_value = mock_collection - - await check_node_timeout() - - mock_logger.error.assert_called_once() - error_message = mock_logger.error.call_args[0][0] - assert "Error checking node timeout" in error_message - - @pytest.mark.asyncio - async def test_check_node_timeout_calculates_correct_threshold(self): - mock_collection = MagicMock() - mock_result = MagicMock() - mock_result.modified_count = 0 - mock_collection.update_many = AsyncMock(return_value=mock_result) - - with patch('app.tasks.check_node_timeout.State') as mock_state, \ - patch('app.tasks.check_node_timeout.get_settings') as mock_get_settings, \ - patch('app.tasks.check_node_timeout.time') as mock_time: - - from app.tasks.check_node_timeout import check_node_timeout - - mock_time.time.return_value = 1700000000 - - mock_settings = MagicMock() - mock_settings.node_timeout_minutes = 45 - mock_get_settings.return_value = mock_settings - - mock_state.get_pymongo_collection.return_value = mock_collection - - await check_node_timeout() - - call_args = mock_collection.update_many.call_args - query = call_args[0][0] - - expected_threshold = (1700000000 * 1000) - (45 * 60 * 1000) - assert query["queued_at"]["$lte"] == expected_threshold +@pytest.mark.asyncio +async def test_check_node_timeout_with_different_timeouts(): + """Test that nodes with different timeout_minutes are handled correctly""" + + # Mock current time (Unix timestamp in seconds: 1700000000 = Nov 14, 2023) + current_time_ms = 1700000000 * 1000 # Convert to milliseconds + + # Create mock states with different timeouts + state1 = MagicMock() + state1.status = StateStatusEnum.QUEUED + state1.queued_at = current_time_ms - (35 * 60 * 1000) # 35 minutes ago + state1.timeout_minutes = 30 # Should timeout (35 > 30) + + state2 = MagicMock() + state2.status = StateStatusEnum.QUEUED + state2.queued_at = current_time_ms - (25 * 60 * 1000) # 25 minutes ago + state2.timeout_minutes = 30 # Should NOT timeout (25 < 30) + + state3 = MagicMock() + state3.status = StateStatusEnum.QUEUED + state3.queued_at = current_time_ms - (45 * 60 * 1000) # 45 minutes ago + state3.timeout_minutes = None # Use global setting (30 min) + + mock_states = [state1, state2, state3] + + # Mock settings + mock_settings = MagicMock() + mock_settings.node_timeout_minutes = 30 # Global setting + + with patch('app.tasks.check_node_timeout.State') as mock_state_class, \ + patch('app.tasks.check_node_timeout.get_settings', return_value=mock_settings), \ + patch('app.tasks.check_node_timeout.time.time', return_value=current_time_ms / 1000): + + # Mock State.find().to_list() to return our mock states + mock_state_class.find.return_value.to_list = AsyncMock(return_value=mock_states) + mock_state_class.save_all = AsyncMock() + + await check_node_timeout() + + # Verify state1 and state3 were marked as TIMEDOUT, but not state2 + assert state1.status == StateStatusEnum.TIMEDOUT + assert state1.error == "Node execution timed out after 30 minutes" + + assert state2.status == StateStatusEnum.QUEUED # Should remain QUEUED + + assert state3.status == StateStatusEnum.TIMEDOUT + assert state3.error == "Node execution timed out after 30 minutes" + + # Verify save_all was called with the 2 timed out states + mock_state_class.save_all.assert_called_once() + saved_states = mock_state_class.save_all.call_args[0][0] + assert len(saved_states) == 2 + assert state1 in saved_states + assert state3 in saved_states + + +@pytest.mark.asyncio +async def test_check_node_timeout_no_timeouts(): + """Test that no states are marked as timed out when none exceed their timeout""" + + current_time_ms = 1700000000 * 1000 + + # Create mock state that hasn't timed out + state1 = MagicMock() + state1.status = StateStatusEnum.QUEUED + state1.queued_at = current_time_ms - (10 * 60 * 1000) # 10 minutes ago + state1.timeout_minutes = 30 # Should NOT timeout + + mock_states = [state1] + mock_settings = MagicMock() + mock_settings.node_timeout_minutes = 30 + + with patch('app.tasks.check_node_timeout.State') as mock_state_class, \ + patch('app.tasks.check_node_timeout.get_settings', return_value=mock_settings), \ + patch('app.tasks.check_node_timeout.time.time', return_value=current_time_ms / 1000): + + mock_state_class.find.return_value.to_list = AsyncMock(return_value=mock_states) + mock_state_class.save_all = AsyncMock() + + await check_node_timeout() + + # Verify state remains QUEUED + assert state1.status == StateStatusEnum.QUEUED + + # Verify save_all was not called since no states timed out + mock_state_class.save_all.assert_not_called() + + +@pytest.mark.asyncio +async def test_check_node_timeout_handles_exception(): + """Test that exceptions in check_node_timeout are logged properly""" + + with patch('app.tasks.check_node_timeout.State') as mock_state_class, \ + patch('app.tasks.check_node_timeout.logger') as mock_logger: + + # Mock State.find to raise an exception + mock_state_class.find.side_effect = Exception("Database error") + + await check_node_timeout() + + # Verify error was logged with exc_info + mock_logger.error.assert_called_once_with("Error checking node timeout", exc_info=True) + + +@pytest.mark.asyncio +async def test_check_node_timeout_custom_node_timeout(): + """Test that nodes with custom timeout_minutes use their own timeout value""" + + current_time_ms = 1700000000 * 1000 + + # Create mock state with custom timeout + state1 = MagicMock() + state1.status = StateStatusEnum.QUEUED + state1.queued_at = current_time_ms - (35 * 60 * 1000) # 35 minutes ago + state1.timeout_minutes = 60 # Custom timeout of 60 minutes - should NOT timeout + + # Create mock state with global timeout + state2 = MagicMock() + state2.status = StateStatusEnum.QUEUED + state2.queued_at = current_time_ms - (35 * 60 * 1000) # 35 minutes ago + state2.timeout_minutes = None # Use global setting (30 min) - should timeout + + mock_states = [state1, state2] + mock_settings = MagicMock() + mock_settings.node_timeout_minutes = 30 # Global setting + + with patch('app.tasks.check_node_timeout.State') as mock_state_class, \ + patch('app.tasks.check_node_timeout.get_settings', return_value=mock_settings), \ + patch('app.tasks.check_node_timeout.time.time', return_value=current_time_ms / 1000): + + mock_state_class.find.return_value.to_list = AsyncMock(return_value=mock_states) + mock_state_class.save_all = AsyncMock() + + await check_node_timeout() + + # Verify only state2 was marked as TIMEDOUT (using global 30 min timeout) + assert state1.status == StateStatusEnum.QUEUED # Custom 60 min timeout + + assert state2.status == StateStatusEnum.TIMEDOUT # Global 30 min timeout + assert state2.error == "Node execution timed out after 30 minutes" + + # Verify save_all was called with only state2 + mock_state_class.save_all.assert_called_once() + saved_states = mock_state_class.save_all.call_args[0][0] + assert len(saved_states) == 1 + assert state2 in saved_states \ No newline at end of file From 4ee52fd6f49fb9701f2ee06187a10635d3d4e983 Mon Sep 17 00:00:00 2001 From: Agampreet Singh Date: Fri, 3 Oct 2025 18:47:12 +0530 Subject: [PATCH 4/7] minor fixes --- state-manager/app/models/db/state.py | 1 - state-manager/app/tasks/check_node_timeout.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index 1b3ece7c..9c1b3124 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -30,7 +30,6 @@ class State(BaseDatabaseModel): manual_retry_fanout_id: str = Field(default="", description="Fanout ID from a manual retry request, ensuring unique retries for unite nodes.") queued_at: Optional[int] = Field(None, description="Unix time in milliseconds when state was queued") timeout_minutes: Optional[int] = Field(None, gt=0, description="Timeout in minutes for this specific state, taken from node registration") - queued_at: Optional[int] = Field(None, description="Unix time in milliseconds when the state was queued") @before_event([Insert, Replace, Save]) def _generate_fingerprint(self): diff --git a/state-manager/app/tasks/check_node_timeout.py b/state-manager/app/tasks/check_node_timeout.py index 4ccdbaab..6d13756a 100644 --- a/state-manager/app/tasks/check_node_timeout.py +++ b/state-manager/app/tasks/check_node_timeout.py @@ -1,4 +1,5 @@ import time +from beanie.operators import Ne from app.models.db.state import State from app.models.state_status_enum import StateStatusEnum from app.singletons.logs_manager import LogsManager @@ -17,7 +18,7 @@ async def check_node_timeout(): # Find all QUEUED states with queued_at set queued_states = await State.find( State.status == StateStatusEnum.QUEUED, - State.queued_at != None + Ne(State.queued_at, None) ).to_list() states_to_timeout = [] From 4be13c9aa8182ec5b053af07a06146f3d1fdcda4 Mon Sep 17 00:00:00 2001 From: Agampreet Singh Date: Fri, 10 Oct 2025 01:32:13 +0530 Subject: [PATCH 5/7] minor fixes --- .../app/controller/enqueue_states.py | 69 ++++++++++++++++--- state-manager/app/models/db/state.py | 3 +- state-manager/app/tasks/check_node_timeout.py | 40 ++++------- 3 files changed, 78 insertions(+), 34 deletions(-) diff --git a/state-manager/app/controller/enqueue_states.py b/state-manager/app/controller/enqueue_states.py index 22689e49..11172bd2 100644 --- a/state-manager/app/controller/enqueue_states.py +++ b/state-manager/app/controller/enqueue_states.py @@ -7,6 +7,7 @@ from ..models.state_status_enum import StateStatusEnum from app.singletons.logs_manager import LogsManager +from app.config.settings import get_settings from pymongo import ReturnDocument logger = LogsManager().get_logger() @@ -14,21 +15,73 @@ async def find_state(namespace_name: str, nodes: list[str]) -> State | None: current_time_ms = int(time.time() * 1000) + settings = get_settings() + + # Use pipeline to calculate timeout_at based on state-specific or global timeout + pipeline = [ + { + "$match": { + "namespace_name": namespace_name, + "status": StateStatusEnum.CREATED, + "node_name": {"$in": nodes}, + "enqueue_after": {"$lte": current_time_ms} + } + }, + { + "$addFields": { + "status": StateStatusEnum.QUEUED, + "queued_at": current_time_ms, + "timeout_at": { + "$add": [ + current_time_ms, + { + "$multiply": [ + { + "$ifNull": [ + "$timeout_minutes", + settings.node_timeout_minutes + ] + }, + 60000 # Convert minutes to milliseconds + ] + } + ] + } + } + } + ] + data = await State.get_pymongo_collection().find_one_and_update( { "namespace_name": namespace_name, "status": StateStatusEnum.CREATED, - "node_name": { - "$in": nodes - }, + "node_name": {"$in": nodes}, "enqueue_after": {"$lte": current_time_ms} }, - { - "$set": { - "status": StateStatusEnum.QUEUED, - "queued_at": current_time_ms + [ + { + "$set": { + "status": StateStatusEnum.QUEUED, + "queued_at": current_time_ms, + "timeout_at": { + "$add": [ + current_time_ms, + { + "$multiply": [ + { + "$ifNull": [ + "$timeout_minutes", + settings.node_timeout_minutes + ] + }, + 60000 # Convert minutes to milliseconds + ] + } + ] + } + } } - }, + ], return_document=ReturnDocument.AFTER ) return State(**data) if data else None diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index 9c1b3124..3837de2a 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -29,6 +29,7 @@ class State(BaseDatabaseModel): fanout_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Fanout ID of the state") manual_retry_fanout_id: str = Field(default="", description="Fanout ID from a manual retry request, ensuring unique retries for unite nodes.") queued_at: Optional[int] = Field(None, description="Unix time in milliseconds when state was queued") + timeout_at: Optional[int] = Field(None, description="Unix time in milliseconds when state times out") timeout_minutes: Optional[int] = Field(None, gt=0, description="Timeout in minutes for this specific state, taken from node registration") @before_event([Insert, Replace, Save]) @@ -108,7 +109,7 @@ class Settings: IndexModel( [ ("status", 1), - ("queued_at", 1), + ("timeout_at", 1), ], name="timeout_query_index" ) diff --git a/state-manager/app/tasks/check_node_timeout.py b/state-manager/app/tasks/check_node_timeout.py index 6d13756a..f4aa3216 100644 --- a/state-manager/app/tasks/check_node_timeout.py +++ b/state-manager/app/tasks/check_node_timeout.py @@ -1,43 +1,33 @@ import time -from beanie.operators import Ne from app.models.db.state import State from app.models.state_status_enum import StateStatusEnum from app.singletons.logs_manager import LogsManager -from app.config.settings import get_settings logger = LogsManager().get_logger() async def check_node_timeout(): try: - settings = get_settings() current_time_ms = int(time.time() * 1000) logger.info(f"Checking for timed out nodes at {current_time_ms}") - # Find all QUEUED states with queued_at set - queued_states = await State.find( - State.status == StateStatusEnum.QUEUED, - Ne(State.queued_at, None) - ).to_list() + # Use database query to find and update timed out states in one operation + result = await State.get_pymongo_collection().update_many( + { + "status": StateStatusEnum.QUEUED, + "timeout_at": {"$ne": None, "$lte": current_time_ms} + }, + { + "$set": { + "status": StateStatusEnum.TIMEDOUT, + "error": "Node execution timed out" + } + } + ) - states_to_timeout = [] - - for state in queued_states: - # Use state-specific timeout if available, otherwise fall back to global - timeout_minutes = state.timeout_minutes if state.timeout_minutes else settings.node_timeout_minutes - timeout_ms = timeout_minutes * 60 * 1000 - timeout_threshold = current_time_ms - timeout_ms - - if state.queued_at <= timeout_threshold: - state.status = StateStatusEnum.TIMEDOUT - state.error = f"Node execution timed out after {timeout_minutes} minutes" - states_to_timeout.append(state) - - if states_to_timeout: - # Update all timed out states in bulk - await State.save_all(states_to_timeout) - logger.info(f"Marked {len(states_to_timeout)} states as TIMEDOUT") + if result.modified_count > 0: + logger.info(f"Marked {result.modified_count} states as TIMEDOUT") except Exception: logger.error("Error checking node timeout", exc_info=True) From de475b80802d271191bbdb87e6901e91f69a9f33 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Thu, 23 Oct 2025 12:12:47 +0530 Subject: [PATCH 6/7] fix: increase node timeout from 30 to 60 minutes in settings.py Updated the default value of node_timeout_minutes to enhance the timeout duration for nodes stuck in QUEUED status, improving system resilience and performance. --- state-manager/app/config/settings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/state-manager/app/config/settings.py b/state-manager/app/config/settings.py index 41a5e7d4..f5f4016f 100644 --- a/state-manager/app/config/settings.py +++ b/state-manager/app/config/settings.py @@ -13,7 +13,7 @@ class Settings(BaseModel): state_manager_secret: str = Field(..., description="Secret key for API authentication") secrets_encryption_key: str = Field(..., description="Key for encrypting secrets") trigger_workers: int = Field(default=1, description="Number of workers to run the trigger cron") - node_timeout_minutes: int = Field(default=30, gt=0, description="Timeout in minutes for nodes stuck in QUEUED status") + node_timeout_minutes: int = Field(default=60, gt=0, description="Timeout in minutes for nodes stuck in QUEUED status") trigger_retention_hours: int = Field(default=720, description="Number of hours to retain completed/failed triggers before cleanup") @classmethod @@ -23,7 +23,7 @@ def from_env(cls) -> "Settings": mongo_database_name=os.getenv("MONGO_DATABASE_NAME", "exosphere-state-manager"), # type: ignore state_manager_secret=os.getenv("STATE_MANAGER_SECRET"), # type: ignore secrets_encryption_key=os.getenv("SECRETS_ENCRYPTION_KEY"), # type: ignore - node_timeout_minutes=os.getenv("NODE_TIMEOUT_MINUTES", "30") # type: ignore + node_timeout_minutes=int(os.getenv("NODE_TIMEOUT_MINUTES", 60)), # type: ignore trigger_workers=int(os.getenv("TRIGGER_WORKERS", 1)), # type: ignore trigger_retention_hours=int(os.getenv("TRIGGER_RETENTION_HOURS", 720)) # type: ignore ) From 2c7f3c1326c63802b0d668bb8a47668b42dfafd8 Mon Sep 17 00:00:00 2001 From: NiveditJain Date: Thu, 23 Oct 2025 12:59:49 +0530 Subject: [PATCH 7/7] refactor: streamline state management by removing unused pipeline and enhancing state attributes - Removed the unused pipeline for calculating timeout_at in find_state function. - Added timeout_at and timeout_minutes attributes to the state model for better state management. - Updated errored_state and manual_retry_state functions to include timeout_at in state creation. - Set timeout_at and queued_at to None in re_queue_after_signal function to reset state attributes upon re-queuing. --- .../app/controller/enqueue_states.py | 34 ------------------- state-manager/app/controller/errored_state.py | 3 +- .../app/controller/manual_retry_state.py | 3 +- .../app/controller/re_queue_after_signal.py | 2 ++ state-manager/app/models/db/state.py | 6 ++-- 5 files changed, 9 insertions(+), 39 deletions(-) diff --git a/state-manager/app/controller/enqueue_states.py b/state-manager/app/controller/enqueue_states.py index 11172bd2..e9293c06 100644 --- a/state-manager/app/controller/enqueue_states.py +++ b/state-manager/app/controller/enqueue_states.py @@ -17,40 +17,6 @@ async def find_state(namespace_name: str, nodes: list[str]) -> State | None: current_time_ms = int(time.time() * 1000) settings = get_settings() - # Use pipeline to calculate timeout_at based on state-specific or global timeout - pipeline = [ - { - "$match": { - "namespace_name": namespace_name, - "status": StateStatusEnum.CREATED, - "node_name": {"$in": nodes}, - "enqueue_after": {"$lte": current_time_ms} - } - }, - { - "$addFields": { - "status": StateStatusEnum.QUEUED, - "queued_at": current_time_ms, - "timeout_at": { - "$add": [ - current_time_ms, - { - "$multiply": [ - { - "$ifNull": [ - "$timeout_minutes", - settings.node_timeout_minutes - ] - }, - 60000 # Convert minutes to milliseconds - ] - } - ] - } - } - } - ] - data = await State.get_pymongo_collection().find_one_and_update( { "namespace_name": namespace_name, diff --git a/state-manager/app/controller/errored_state.py b/state-manager/app/controller/errored_state.py index e8eb5331..2be7dd76 100644 --- a/state-manager/app/controller/errored_state.py +++ b/state-manager/app/controller/errored_state.py @@ -53,7 +53,8 @@ async def errored_state(namespace_name: str, state_id: PydanticObjectId, body: E does_unites=state.does_unites, enqueue_after= int(time.time() * 1000) + graph_template.retry_policy.compute_delay(state.retry_count + 1), retry_count=state.retry_count + 1, - fanout_id=state.fanout_id + fanout_id=state.fanout_id, + timeout_at=state.timeout_at ) retry_state = await retry_state.insert() logger.info(f"Retry state {retry_state.id} created for state {state_id}", x_exosphere_request_id=x_exosphere_request_id) diff --git a/state-manager/app/controller/manual_retry_state.py b/state-manager/app/controller/manual_retry_state.py index b2ac3f47..b24de2d5 100644 --- a/state-manager/app/controller/manual_retry_state.py +++ b/state-manager/app/controller/manual_retry_state.py @@ -31,7 +31,8 @@ async def manual_retry_state(namespace_name: str, state_id: PydanticObjectId, bo parents=state.parents, does_unites=state.does_unites, fanout_id=body.fanout_id, # this will ensure that multiple unwanted retries are not formed because of index in database - manual_retry_fanout_id=body.fanout_id # This is included in the state fingerprint to allow unique manual retries of unite nodes. + manual_retry_fanout_id=body.fanout_id, # This is included in the state fingerprint to allow unique manual retries of unite nodes. + timeout_minutes=state.timeout_minutes ) retry_state = await retry_state.insert() logger.info(f"Retry state {retry_state.id} created for state {state_id}", x_exosphere_request_id=x_exosphere_request_id) diff --git a/state-manager/app/controller/re_queue_after_signal.py b/state-manager/app/controller/re_queue_after_signal.py index 009f1424..682be841 100644 --- a/state-manager/app/controller/re_queue_after_signal.py +++ b/state-manager/app/controller/re_queue_after_signal.py @@ -21,6 +21,8 @@ async def re_queue_after_signal(namespace_name: str, state_id: PydanticObjectId, state.status = StateStatusEnum.CREATED state.enqueue_after = int(time.time() * 1000) + body.enqueue_after + state.timeout_at = None + state.queued_at = None await state.save() return SignalResponseModel(status=StateStatusEnum.CREATED, enqueue_after=state.enqueue_after) diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index 3837de2a..942b258f 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -28,9 +28,9 @@ class State(BaseDatabaseModel): retry_count: int = Field(default=0, description="Number of times the state has been retried") fanout_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Fanout ID of the state") manual_retry_fanout_id: str = Field(default="", description="Fanout ID from a manual retry request, ensuring unique retries for unite nodes.") - queued_at: Optional[int] = Field(None, description="Unix time in milliseconds when state was queued") - timeout_at: Optional[int] = Field(None, description="Unix time in milliseconds when state times out") - timeout_minutes: Optional[int] = Field(None, gt=0, description="Timeout in minutes for this specific state, taken from node registration") + queued_at: Optional[int] = Field(default=None, description="Unix time in milliseconds when state was queued") + timeout_at: Optional[int] = Field(default=None, description="Unix time in milliseconds when state times out") + timeout_minutes: Optional[int] = Field(default=None, gt=0, description="Timeout in minutes for this specific state, taken from node registration") @before_event([Insert, Replace, Save]) def _generate_fingerprint(self):