diff --git a/packages/polaris/polaris/polaris/__init__.py b/packages/polaris/polaris/polaris/__init__.py index eb680553..c8e79c8d 100644 --- a/packages/polaris/polaris/polaris/__init__.py +++ b/packages/polaris/polaris/polaris/__init__.py @@ -1,5 +1,5 @@ from .modules.registry import Registry from .modules.runner import Runner -from .runtime import run +from .runtime import initialize, is_initialized, run -__all__ = ["run", "Registry", "Runner"] +__all__ = ["initialize", "is_initialized", "run", "Registry", "Runner"] diff --git a/packages/polaris/polaris/polaris/core/client.py b/packages/polaris/polaris/polaris/core/client.py index 1973dd65..e05c2a20 100644 --- a/packages/polaris/polaris/polaris/core/client.py +++ b/packages/polaris/polaris/polaris/core/client.py @@ -1,6 +1,13 @@ +"""HTTP client with automatic retry and environment detection. + +Provides adaptive HTTP client that works in both browser (Pyodide) and +server environments with built-in retry logic for transient failures. +""" + import asyncio import json import logging +from typing import Any, Awaitable, Callable, TypeVar from .exceptions import HttpError @@ -11,13 +18,81 @@ MAX_RETRIES = 3 INITIAL_BACKOFF = 1.0 # seconds +T = TypeVar("T") + + +async def _retry_request( + request_fn: Callable[[], Awaitable[tuple[int, str, T | None]]], + url: str, + method: str, +) -> T: + """Execute HTTP request with retry logic for transient failures. + + Args: + request_fn: Async function that returns (status_code, response_text, parsed_data). + parsed_data is None if request failed. + url: Request URL (for error context) + method: HTTP method (for error context) + + Returns: + Parsed response data on success + + Raises: + HttpError: On non-retryable error or after all retries exhausted + """ + last_error: HttpError | None = None + + for attempt in range(MAX_RETRIES): + status, text, data = await request_fn() + + # Success + if data is not None: + return data + + # Non-retryable error + if status not in RETRY_STATUS_CODES: + raise HttpError( + f"HTTP {status}: {text}", + status_code=status, + details={"url": url, "method": method}, + ) + + # Retryable error - store and possibly retry + last_error = HttpError( + f"HTTP {status}: {text}", + status_code=status, + details={"url": url, "method": method}, + ) + + if attempt < MAX_RETRIES - 1: + backoff = INITIAL_BACKOFF * (2**attempt) + logger.warning( + f"HTTP {status}, retrying in {backoff}s (attempt {attempt + 1}/{MAX_RETRIES})" + ) + await asyncio.sleep(backoff) + + # All retries exhausted + if last_error is None: + # This should never happen since MAX_RETRIES >= 1 + raise HttpError( + "Request failed with no error captured", + status_code=0, + details={"url": url, "method": method}, + ) + raise last_error + class HttpClient: - async def request(self, method, url, headers=None, body=None): + """Base HTTP client interface.""" + + async def request( + self, method: str, url: str, headers: dict[str, str] | None = None, body: Any = None + ) -> Any: raise NotImplementedError -def is_pyodide(): +def is_pyodide() -> bool: + """Check if running in Pyodide (browser) environment.""" try: import pyodide_js # type: ignore[import-not-found] # noqa: F401 @@ -26,8 +101,8 @@ def is_pyodide(): return False -# parse response without relying on content type -async def parse_response(response): +async def _parse_response(response: Any) -> Any: + """Parse response without relying on content type.""" text = await response.text() try: return json.loads(text) @@ -41,16 +116,20 @@ async def parse_response(response): class BrowserHttpClient(HttpClient): - def __init__(self): + """HTTP client for browser/Pyodide environment using fetch API.""" + + def __init__(self) -> None: from js import fetch # type: ignore[import-not-found] from pyodide.ffi import to_js # type: ignore[import-not-found] self._fetch = fetch self._to_js = to_js - async def request(self, method, url, headers=None, body=None): + async def request( + self, method: str, url: str, headers: dict[str, str] | None = None, body: Any = None + ) -> Any: headers = headers or {} - options = { + options: dict[str, Any] = { "method": method.upper(), "headers": headers, } @@ -58,36 +137,15 @@ async def request(self, method, url, headers=None, body=None): options["body"] = json.dumps(body) headers.setdefault("Content-Type", "application/json") - last_error = None - for attempt in range(MAX_RETRIES): + async def do_request() -> tuple[int, str, Any | None]: response = await self._fetch(url, self._to_js(options)) if response.ok: - return await parse_response(response) - - status = response.status + data = await _parse_response(response) + return response.status, "", data text = await response.text() + return response.status, text, None - if status not in RETRY_STATUS_CODES: - # Don't retry client errors (except 429) - raise HttpError( - f"HTTP {status}: {text}", - status_code=status, - details={"url": url, "method": method}, - ) - - last_error = HttpError( - f"HTTP {status}: {text}", - status_code=status, - details={"url": url, "method": method}, - ) - - if attempt < MAX_RETRIES - 1: - backoff = INITIAL_BACKOFF * (2**attempt) - logger.warning(f"HTTP {status}, retrying in {backoff}s " f"(attempt {attempt + 1}/{MAX_RETRIES})") - await asyncio.sleep(backoff) - - assert last_error is not None # Always set after at least one iteration - raise last_error + return await _retry_request(do_request, url, method) # ---------------------------- @@ -96,20 +154,23 @@ async def request(self, method, url, headers=None, body=None): class ServerHttpClient(HttpClient): - def __init__(self): + """HTTP client for server environment using aiohttp.""" + + def __init__(self) -> None: import aiohttp self._aiohttp = aiohttp - async def request(self, method, url, headers=None, body=None): + async def request( + self, method: str, url: str, headers: dict[str, str] | None = None, body: Any = None + ) -> Any: data = None if body is not None: data = json.dumps(body) headers = headers or {} headers.setdefault("Content-Type", "application/json") - last_error = None - for attempt in range(MAX_RETRIES): + async def do_request() -> tuple[int, str, Any | None]: async with self._aiohttp.ClientSession() as session: async with session.request( method=method.upper(), @@ -118,32 +179,12 @@ async def request(self, method, url, headers=None, body=None): data=data, ) as response: if response.status < 400: - return await parse_response(response) - - status = response.status + parsed = await _parse_response(response) + return response.status, "", parsed text = await response.text() + return response.status, text, None - if status not in RETRY_STATUS_CODES: - # Don't retry client errors (except 429) - raise HttpError( - f"HTTP {status}: {text}", - status_code=status, - details={"url": url, "method": method}, - ) - - last_error = HttpError( - f"HTTP {status}: {text}", - status_code=status, - details={"url": url, "method": method}, - ) - - if attempt < MAX_RETRIES - 1: - backoff = INITIAL_BACKOFF * (2**attempt) - logger.warning(f"HTTP {status}, retrying in {backoff}s " f"(attempt {attempt + 1}/{MAX_RETRIES})") - await asyncio.sleep(backoff) - - assert last_error is not None # Always set after at least one iteration - raise last_error + return await _retry_request(do_request, url, method) # ---------------------------- diff --git a/packages/polaris/polaris/polaris/core/rate_limiter.py b/packages/polaris/polaris/polaris/core/rate_limiter.py index 3739e24d..ccddc6fb 100644 --- a/packages/polaris/polaris/polaris/core/rate_limiter.py +++ b/packages/polaris/polaris/polaris/core/rate_limiter.py @@ -79,7 +79,15 @@ def _refill(self) -> None: @property def available_tokens(self) -> float: - """Return the current number of available tokens (approximate).""" + """Return approximate token count for monitoring/logging. + + Note: This returns the cached token count without refilling or locking. + The value may be stale by up to (1/rate) seconds. This is intentional + to avoid blocking on a property access. + + For accurate token acquisition, always use `acquire()` which properly + handles locking and refilling. + """ return self.tokens @classmethod diff --git a/packages/polaris/polaris/polaris/core/retry.py b/packages/polaris/polaris/polaris/core/retry.py index 5ed51fe5..4808f008 100644 --- a/packages/polaris/polaris/polaris/core/retry.py +++ b/packages/polaris/polaris/polaris/core/retry.py @@ -48,5 +48,8 @@ async def retry_async( on_retry(e, attempt + 1, backoff) await asyncio.sleep(backoff) - assert last_error is not None + # last_error is guaranteed to be set after at least one iteration + # since max_retries >= 1 and we only reach here if all attempts failed + if last_error is None: + raise RuntimeError("Retry loop completed without capturing an error") raise last_error diff --git a/packages/polaris/polaris/polaris/modules/constants.py b/packages/polaris/polaris/polaris/modules/constants.py index 03c66d94..7cd7dc16 100644 --- a/packages/polaris/polaris/polaris/modules/constants.py +++ b/packages/polaris/polaris/polaris/modules/constants.py @@ -10,6 +10,7 @@ class NodeType(str, Enum): CONTROL = "control" EXECUTOR = "executor" LOOP = "loop" + MATERIALIZER = "materializer" PLANNER = "planner" REASONING = "reasoning" TERMINAL = "terminal" @@ -41,6 +42,11 @@ class ErrorCode(str, Enum): UNKNOWN_NODE_TYPE = "unknown_node_type" LOOP_INVALID_OVER = "loop_invalid_over" LOOP_ITERATION_FAILED = "loop_iteration_failed" + MATERIALIZER_FAILED = "materializer_failed" + MATERIALIZER_NOT_FOUND = "materializer_not_found" + MATERIALIZER_INVALID_ARGS = "materializer_invalid_args" + PLANNER_INVALID_JSON = "planner_invalid_json" + PLANNER_SCHEMA_VALIDATION_FAILED = "planner_schema_validation_failed" TRAVERSE_INVALID_CONFIG = "traverse_invalid_config" TRAVERSE_FETCH_FAILED = "traverse_fetch_failed" diff --git a/packages/polaris/polaris/polaris/modules/handlers/__init__.py b/packages/polaris/polaris/polaris/modules/handlers/__init__.py index fda87634..de837a6f 100644 --- a/packages/polaris/polaris/polaris/modules/handlers/__init__.py +++ b/packages/polaris/polaris/polaris/modules/handlers/__init__.py @@ -7,6 +7,7 @@ from .control import ControlHandler from .executor import ExecutorHandler from .loop import LoopHandler +from .materializer import MaterializerHandler from .planner import PlannerHandler from .reasoning import ReasoningHandler from .terminal import TerminalHandler @@ -18,6 +19,7 @@ NodeType.CONTROL: ControlHandler, NodeType.EXECUTOR: ExecutorHandler, NodeType.LOOP: LoopHandler, + NodeType.MATERIALIZER: MaterializerHandler, NodeType.PLANNER: PlannerHandler, NodeType.REASONING: ReasoningHandler, NodeType.TERMINAL: TerminalHandler, @@ -57,6 +59,7 @@ def get_handler(node_type: str) -> NodeHandler | None: "ControlHandler", "ExecutorHandler", "LoopHandler", + "MaterializerHandler", "PlannerHandler", "ReasoningHandler", "TerminalHandler", diff --git a/packages/polaris/polaris/polaris/modules/handlers/materializer.py b/packages/polaris/polaris/polaris/modules/handlers/materializer.py new file mode 100644 index 00000000..663db204 --- /dev/null +++ b/packages/polaris/polaris/polaris/modules/handlers/materializer.py @@ -0,0 +1,123 @@ +"""Handler for materializer nodes. + +A materializer node invokes a pure Python function with explicit arguments. +It is deterministic, has no LLM calls, no branching on content, and no +side effects outside the optional workspace path. +""" + +import logging +import traceback +from typing import TYPE_CHECKING, Any + +import jsonschema + +from ..constants import ErrorCode +from ..materializers import catalog +from ..types import Context, NodeDefinition, Result + +if TYPE_CHECKING: + from ..registry import Registry + +logger = logging.getLogger(__name__) + + +class MaterializerHandler: + """Handler for materializer nodes. + + Executes a pre-registered Python function with resolved arguments. + The function must be registered in the materializer catalog before + the runtime is initialized. + """ + + async def execute( + self, + node: NodeDefinition, + ctx: Context, + registry: "Registry", + runner: Any, + ) -> Result: + """Execute the materializer node. + + Args: + node: The node definition containing target, args, workspace, input_schema + ctx: The execution context + registry: The registry (unused for materializers) + runner: The runner instance with resolver + + Returns: + Result dict with ok=True and result on success, + or ok=False with error on failure + """ + _ = registry # Materializers don't use the registry + + target = node.get("target") + args_spec = node.get("args", {}) + workspace_spec = node.get("workspace") + input_schema = node.get("input_schema") + + logger.debug(f"Materializer executing: {target}") + + # Get the materializer function from the catalog + try: + fn = catalog.get(target) + except KeyError as e: + logger.error(f"Materializer not found: {target}") + return { + "ok": False, + "error": { + "code": ErrorCode.MATERIALIZER_NOT_FOUND, + "message": str(e), + }, + } + + # Resolve arguments + args = {} + for key, value in args_spec.items(): + args[key] = runner.resolver.resolve(value, ctx) + + # Resolve and add workspace if specified + if workspace_spec is not None: + workspace = runner.resolver.resolve(workspace_spec, ctx) + args["workspace"] = workspace + + # Eager validation against input schema if provided + if input_schema is not None: + try: + jsonschema.validate(args, input_schema) + except jsonschema.ValidationError as e: + logger.error(f"Materializer argument validation failed: {e.message}") + return { + "ok": False, + "error": { + "code": ErrorCode.MATERIALIZER_INVALID_ARGS, + "message": f"Argument validation failed: {e.message}", + "details": {"path": list(e.path), "schema_path": list(e.schema_path)}, + }, + } + + # Execute the materializer function + try: + result = fn(**args) + logger.debug(f"Materializer {target} completed successfully") + + # Set result in context for emit rules + ctx["result"] = result + + # Apply emit rules if present + emit = node.get("emit") + if emit: + runner.resolver.apply_emit(emit, {"result": result}, ctx) + + return {"ok": True, "result": result} + + except Exception as e: + tb = traceback.format_exc() + logger.error(f"Materializer {target} failed: {e}\n{tb}") + return { + "ok": False, + "error": { + "code": ErrorCode.MATERIALIZER_FAILED, + "message": str(e), + "details": {"traceback": tb}, + }, + } diff --git a/packages/polaris/polaris/polaris/modules/handlers/planner.py b/packages/polaris/polaris/polaris/modules/handlers/planner.py index c27f700e..f68ccec6 100644 --- a/packages/polaris/polaris/polaris/modules/handlers/planner.py +++ b/packages/polaris/polaris/polaris/modules/handlers/planner.py @@ -1,8 +1,17 @@ -"""Handler for planner nodes.""" +"""Handler for planner nodes. +Planners emit structured JSON output, never tool calls. +- Route planners: emit {"route": } for control flow +- JSON planners: emit parameter objects for downstream computation +""" + +import json import logging from typing import TYPE_CHECKING, Any +import jsonschema + +from ..constants import ErrorCode from ..types import Context, NodeDefinition, Result if TYPE_CHECKING: @@ -11,8 +20,96 @@ logger = logging.getLogger(__name__) +def build_route_schema(routes: dict[str, Any]) -> dict[str, Any]: + """Build JSON schema for route enum. + + Args: + routes: Dict mapping route names to route specs + + Returns: + JSON Schema that validates {"route": } + """ + return { + "type": "object", + "required": ["route"], + "properties": {"route": {"enum": list(routes.keys())}}, + "additionalProperties": False, + } + + +class PlannerOutputShim: + """Parses and validates planner JSON output. Nothing else. + + The shim is stateless and has no knowledge of routing or control flow. + All control flow decisions are made by the runner based on the validated output. + """ + + def validate( + self, + raw_response: str, + schema: dict[str, Any], + ) -> Result: + """Parse JSON and validate against schema. + + Args: + raw_response: Raw string from LLM + schema: JSON Schema to validate against + + Returns: + Result with validated data or error + """ + # Step 1: Parse JSON + try: + data = json.loads(raw_response) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse planner JSON output: {e}") + return { + "ok": False, + "error": { + "code": ErrorCode.PLANNER_INVALID_JSON, + "message": f"Failed to parse JSON: {e.msg}", + "details": { + "position": e.pos, + "raw_truncated": raw_response[:200], + }, + }, + } + + # Step 2: Validate against schema + try: + jsonschema.validate(data, schema) + except jsonschema.ValidationError as e: + logger.error(f"Planner output failed schema validation: {e.message}") + return { + "ok": False, + "error": { + "code": ErrorCode.PLANNER_SCHEMA_VALIDATION_FAILED, + "message": f"Schema validation failed: {e.message}", + "details": { + "path": list(e.path), + "schema_path": list(e.schema_path), + "value": e.instance, + }, + }, + } + + # Step 3: Return validated data + logger.debug("Planner output validated successfully") + return {"ok": True, "result": data} + + class PlannerHandler: - """Handler for planner nodes.""" + """Handler for planner nodes. + + Planners never emit tool calls. All output is validated JSON: + - Route mode: {"route": } - used for control flow decisions + - JSON mode: parameter object - used for downstream computation + + The shim validates the output. The runner handles route-to-node mapping. + """ + + def __init__(self) -> None: + self.shim = PlannerOutputShim() async def execute( self, @@ -21,18 +118,81 @@ async def execute( registry: "Registry", runner: Any, ) -> Result: - tools = node.get("tools", []) - logger.debug("Planner node starting with %d tools", len(tools)) - planned = await registry.plan( - ctx, - { - "node": node, - "prompt": node.get("prompt", ""), - "tools": tools, - "output_schema": node.get("output_schema"), - }, - ) - logger.debug("Planner node completed") - ctx["result"] = planned - runner.resolver.apply_emit(node.get("emit"), {"result": planned}, ctx) - return {"ok": True, "result": planned} + output_mode = node.get("output_mode") + prompt = self._build_prompt(node, ctx, runner) + + # Build schema based on mode + if output_mode == "route": + schema = build_route_schema(node["routes"]) + else: + schema = node["output_schema"] + + logger.debug(f"Planner executing in {output_mode} mode") + + # Request structured JSON from LLM (no tool calls) + raw_response = await registry.reason_structured(prompt, schema) + + # Validate through shim (shim only validates, nothing else) + result = self.shim.validate(raw_response, schema) + + if not result["ok"]: + logger.warning(f"Planner validation failed: {result['error']}") + return result + + # Set result in context + ctx["result"] = result["result"] + + # Apply emit rules + emit = node.get("emit") + if emit: + runner.resolver.apply_emit(emit, {"result": result["result"]}, ctx) + + logger.debug(f"Planner completed: {result['result']}") + return result + + def _build_prompt( + self, + node: NodeDefinition, + ctx: Context, + runner: Any, + ) -> str: + """Build prompt with context and output instructions.""" + base_prompt = node.get("prompt", "") + output_mode = node.get("output_mode") + + # Resolve input if specified + input_spec = node.get("input") + if input_spec: + resolved_input = runner.resolver.resolve(input_spec, ctx) + else: + resolved_input = None + + if output_mode == "route": + routes = node["routes"] + options = "\n".join( + f'- "{name}": {spec["description"]}' + for name, spec in routes.items() + ) + prompt = f"""{base_prompt} + +Select exactly one route from the following options: +{options} + +Respond with valid JSON in this exact format: {{"route": ""}} +Do not include any other text, explanation, or formatting.""" + + else: + # JSON mode + prompt = f"""{base_prompt} + +Respond with valid JSON matching the required schema. +Do not include any other text, explanation, or formatting.""" + + # Add context if available + if resolved_input: + prompt = f"""{prompt} + +Context: +{json.dumps(resolved_input, indent=2)}""" + + return prompt diff --git a/packages/polaris/polaris/polaris/modules/materializers/__init__.py b/packages/polaris/polaris/polaris/modules/materializers/__init__.py new file mode 100644 index 00000000..d8de7759 --- /dev/null +++ b/packages/polaris/polaris/polaris/modules/materializers/__init__.py @@ -0,0 +1,25 @@ +"""Materializer catalog for pure Python function invocation. + +The catalog is loaded from entry points at framework initialization and frozen +before any pipeline executes. This ensures determinism and auditability. +""" + +from .catalog import ( + MaterializerCatalog, + freeze, + get, + is_frozen, + list_all, + load_entry_points, + register, +) + +__all__ = [ + "MaterializerCatalog", + "freeze", + "get", + "is_frozen", + "list_all", + "load_entry_points", + "register", +] diff --git a/packages/polaris/polaris/polaris/modules/materializers/catalog.py b/packages/polaris/polaris/polaris/modules/materializers/catalog.py new file mode 100644 index 00000000..3ba7fd6f --- /dev/null +++ b/packages/polaris/polaris/polaris/modules/materializers/catalog.py @@ -0,0 +1,156 @@ +"""Materializer catalog with freeze semantics. + +The catalog is populated at framework initialization via entry points, +then frozen before any pipeline executes. Once frozen, no new registrations +are allowed, ensuring that a given YAML always resolves to the same callable set. +""" + +import importlib.metadata +import logging +from typing import Any, Callable + +logger = logging.getLogger(__name__) + +# Type alias for materializer functions +MaterializerFn = Callable[..., Any] + + +class MaterializerCatalog: + """Registry of materializer functions with freeze semantics. + + Entry points are loaded at startup, then the catalog is frozen. + After freezing, no new registrations are allowed. + """ + + def __init__(self) -> None: + self._registry: dict[str, MaterializerFn] = {} + self._frozen: bool = False + + def register(self, name: str, fn: MaterializerFn) -> None: + """Register a materializer function. + + Args: + name: Stable identifier for the materializer + fn: The callable to register + + Raises: + RuntimeError: If the catalog is frozen + ValueError: If a materializer with this name already exists + """ + if self._frozen: + raise RuntimeError(f"Cannot register '{name}': catalog is frozen") + if name in self._registry: + raise ValueError(f"Materializer '{name}' already registered") + self._registry[name] = fn + logger.debug(f"Registered materializer: {name}") + + def get(self, name: str) -> MaterializerFn: + """Get a materializer function by name. + + Args: + name: The materializer identifier + + Returns: + The registered callable + + Raises: + KeyError: If no materializer with this name exists + """ + if name not in self._registry: + raise KeyError(f"Unknown materializer: '{name}'") + return self._registry[name] + + def freeze(self) -> None: + """Freeze the catalog, preventing further registrations.""" + self._frozen = True + logger.info(f"Materializer catalog frozen with {len(self._registry)} entries") + + def is_frozen(self) -> bool: + """Check if the catalog is frozen.""" + return self._frozen + + def list_all(self) -> list[str]: + """List all registered materializer names (for auditability).""" + return sorted(self._registry.keys()) + + def clear(self) -> None: + """Clear the catalog and reset frozen state. For testing only.""" + self._registry.clear() + self._frozen = False + + +# Module-level singleton +_catalog = MaterializerCatalog() + + +def register(name: str) -> Callable[[MaterializerFn], MaterializerFn]: + """Decorator for registering materializer functions. + + Usage: + @register("my_module.my_materializer") + def my_materializer(workspace: str, data: dict) -> str: + ... + + Args: + name: Stable identifier for the materializer + + Returns: + Decorator that registers the function + """ + + def decorator(fn: MaterializerFn) -> MaterializerFn: + _catalog.register(name, fn) + return fn + + return decorator + + +def get(name: str) -> MaterializerFn: + """Get a materializer function by name.""" + return _catalog.get(name) + + +def freeze() -> None: + """Freeze the catalog, preventing further registrations.""" + _catalog.freeze() + + +def is_frozen() -> bool: + """Check if the catalog is frozen.""" + return _catalog.is_frozen() + + +def list_all() -> list[str]: + """List all registered materializer names.""" + return _catalog.list_all() + + +def load_entry_points() -> None: + """Load all materializers from entry points, then freeze. + + This function discovers and loads materializer registration hooks from + installed packages via the 'polaris.materializers' entry point group. + After loading, the catalog is frozen to prevent runtime modifications. + + Raises: + RuntimeError: If the catalog is already frozen + """ + if _catalog.is_frozen(): + raise RuntimeError("Catalog already initialized") + + eps = importlib.metadata.entry_points(group="polaris.materializers") + for ep in eps: + logger.debug(f"Loading materializer entry point: {ep.name}") + try: + register_all = ep.load() + register_all() + except Exception as e: + logger.error(f"Failed to load materializer entry point '{ep.name}': {e}") + raise + + _catalog.freeze() + + +def _get_catalog() -> MaterializerCatalog: + """Get the singleton catalog instance. For testing only.""" + return _catalog diff --git a/packages/polaris/polaris/polaris/modules/registry.py b/packages/polaris/polaris/polaris/modules/registry.py index 94b738e9..4e5c35ba 100644 --- a/packages/polaris/polaris/polaris/modules/registry.py +++ b/packages/polaris/polaris/polaris/modules/registry.py @@ -200,7 +200,61 @@ async def plan(self, ctx, spec): return arguments # ---------------------------- - # Reason + # Reason Structured (JSON output) + # ---------------------------- + async def reason_structured(self, prompt: str, schema: dict) -> str: + """Request structured JSON output from LLM. + + This method requests JSON output from the LLM without using tool calls. + The response is returned as a raw string for the caller to parse and validate. + + Args: + prompt: The prompt including JSON output instructions + schema: JSON Schema describing the expected output (for context) + + Returns: + Raw JSON string response from LLM + """ + # Include schema in the prompt for guidance + schema_str = json.dumps(schema, indent=2) + messages = [ + { + "role": "system", + "content": ( + "You are a structured output assistant. " + "You MUST respond with valid JSON only, no additional text. " + "Your response must conform to the provided schema." + ), + }, + { + "role": "user", + "content": f"{prompt}\n\nRequired JSON Schema:\n{schema_str}", + }, + ] + + reply = await self._completions_post( + { + **self.config, + "messages": messages, + } + ) + + # Extract content from response + choices = reply.get("choices", []) + if not choices: + raise NodeExecutionError("LLM returned no choices") + + message = choices[0].get("message", {}) + content = message.get("content", "") + + if not content or not content.strip(): + raise NodeExecutionError("LLM returned empty content") + + # Return raw content - caller will parse and validate + return content.strip() + + # ---------------------------- + # Reason (text output) # ---------------------------- async def reason(self, prompt, input): messages = [ diff --git a/packages/polaris/polaris/polaris/modules/runner.py b/packages/polaris/polaris/polaris/modules/runner.py index 9a1495ab..ff247735 100644 --- a/packages/polaris/polaris/polaris/modules/runner.py +++ b/packages/polaris/polaris/polaris/modules/runner.py @@ -138,9 +138,25 @@ def _resolve_next(self, node: NodeDefinition, res: Result | None, ctx: Context) next_val = on_handlers["warning"] else: # Success case - if node.get("type") == NodeType.CONTROL: + node_type = node.get("type") + + if node_type == NodeType.CONTROL: + # Control nodes return next in result next_val = res.get("result", {}).get("next") if res else None + + elif node_type == NodeType.PLANNER and node.get("output_mode") == "route": + # Route planners: map enum value to next node + if res and res.get("ok") and res.get("result"): + route_value = res["result"].get("route") + routes = node.get("routes", {}) + if route_value and route_value in routes: + next_val = routes[route_value].get("next") + else: + logger.warning(f"Invalid route value: {route_value}") + next_val = None + else: + # All other nodes: use static 'next' nv = node.get("next") if isinstance(nv, dict): ctx["result"] = res.get("result") if res else None diff --git a/packages/polaris/polaris/polaris/modules/schema.py b/packages/polaris/polaris/polaris/modules/schema.py index 0c4ad90e..bc36b7cc 100644 --- a/packages/polaris/polaris/polaris/modules/schema.py +++ b/packages/polaris/polaris/polaris/modules/schema.py @@ -196,14 +196,86 @@ class ComputeNode(BaseNode): type: Literal["compute"] -class PlannerNode(BaseNode): - """Planner node - AI-powered planning with tool use.""" +class PlannerRouteSpec(BaseModel): + """Route definition with description and target node.""" + + description: str = Field(..., min_length=1) + next: str = Field(..., min_length=1) + + +class PlannerNode(BaseModel): + """Planner node - LLM-powered decision making with structured JSON output. + + Two modes: + - route: Emits {"route": } for control flow decisions. Routes map to next nodes. + - json: Emits a parameter object for downstream computation. Must have static next. + + Planners never emit tool calls. All output is validated JSON. + """ type: Literal["planner"] prompt: str = "" - tools: list[str] = Field(default_factory=list) + output_mode: Literal["route", "json"] + input: dict[str, Any] | None = None + + # Route mode: enum -> node mapping + routes: dict[str, PlannerRouteSpec] | None = None + + # JSON mode: parameter schema output_schema: dict[str, Any] | None = None + # Common fields + emit: dict[str, Any] | None = None + next: str | None = None + + model_config = {"extra": "forbid"} + + @model_validator(mode="after") + def validate_mode_config(self) -> "PlannerNode": + """Validate mode-specific configuration.""" + if self.output_mode == "route": + if self.routes is None: + raise ValueError("Route planners require 'routes'") + if not self.routes: + raise ValueError("Route planners require at least one route") + if self.next is not None: + raise ValueError( + "Route planners cannot have static 'next' - " + "next node is determined by route selection" + ) + if self.output_schema is not None: + raise ValueError("Route planners cannot have 'output_schema'") + + elif self.output_mode == "json": + if self.output_schema is None: + raise ValueError("JSON planners require 'output_schema'") + if self.next is None: + raise ValueError("JSON planners require static 'next'") + if self.routes is not None: + raise ValueError("JSON planners cannot have 'routes'") + + return self + + +class MaterializerNode(BaseNode): + """Materializer node - pure Python function invocation. + + Properties: + - Deterministic execution (no LLM calls, no branching on content) + - Explicit arguments resolved before invocation + - Optional workspace path for file I/O + - No side effects outside the workspace + - No authority over control flow + + The target must reference a pre-registered materializer in the catalog. + """ + + type: Literal["materializer"] + target: str = Field(..., min_length=1, description="Catalog identifier for the materializer") + args: dict[str, DynamicValue] = Field(default_factory=dict, description="Arguments to pass to the function") + workspace: DynamicValue | None = Field(None, description="Optional workspace path for file I/O") + input_schema: dict[str, Any] | None = Field(None, description="JSON Schema for eager argument validation") + # Union of all node types NodeDefinition = Annotated[ @@ -216,6 +288,7 @@ class PlannerNode(BaseNode): LoopNode, ComputeNode, PlannerNode, + MaterializerNode, ], Field(discriminator="type"), ] @@ -263,6 +336,53 @@ def validate_has_terminal(self) -> "AgentDefinition": raise ValueError("Agent must have at least one terminal node") return self + @model_validator(mode="after") + def validate_json_planner_targets(self) -> "AgentDefinition": + """Ensure JSON planners only connect to materializer or compute nodes.""" + allowed_targets = {"materializer", "compute"} + + for node_id, node in self.nodes.items(): + if not isinstance(node, PlannerNode): + continue + if node.output_mode != "json": + continue + if node.next is None: + continue # Caught by node-level validator + + next_node = self.nodes.get(node.next) + if next_node is None: + continue # Caught by existing validator + + next_type = getattr(next_node, "type", None) + if next_type not in allowed_targets: + raise ValueError( + f"JSON planner '{node_id}' connects to {next_type} " + f"node '{node.next}'. JSON planners may only feed: " + f"{', '.join(sorted(allowed_targets))}" + ) + + return self + + @model_validator(mode="after") + def validate_route_planner_targets(self) -> "AgentDefinition": + """Ensure all route planner targets reference existing nodes.""" + for node_id, node in self.nodes.items(): + if not isinstance(node, PlannerNode): + continue + if node.output_mode != "route": + continue + if node.routes is None: + continue # Caught by node-level validator + + for route_name, route_spec in node.routes.items(): + if route_spec.next not in self.nodes: + raise ValueError( + f"Route planner '{node_id}' route '{route_name}' " + f"references non-existent node '{route_spec.next}'" + ) + + return self + def validate_agent(data: dict[str, Any]) -> AgentDefinition: """Validate an agent definition dictionary and return a validated model. diff --git a/packages/polaris/polaris/polaris/runtime.py b/packages/polaris/polaris/polaris/runtime.py index 0c32e202..3c95b9f8 100644 --- a/packages/polaris/polaris/polaris/runtime.py +++ b/packages/polaris/polaris/polaris/runtime.py @@ -1,6 +1,47 @@ +import logging + +from polaris.modules.materializers import catalog as materializer_catalog from polaris.modules.registry import Registry from polaris.modules.runner import ProgressCallback, Runner +logger = logging.getLogger(__name__) + +_initialized = False + + +def initialize() -> None: + """Initialize the Polaris runtime. + + This function must be called once before any agent definitions are loaded + or run() is called. It performs the following: + + 1. Loads materializer functions from entry points + 2. Freezes the materializer catalog + + After initialization, the materializer catalog is immutable, ensuring that + a given YAML always resolves to the same callable set. + + This function is idempotent - calling it multiple times has no effect + after the first call. + """ + global _initialized + if _initialized: + return + + materializer_catalog.load_entry_points() + _initialized = True + + materializers = materializer_catalog.list_all() + if materializers: + logger.info(f"Polaris initialized. Materializers: {materializers}") + else: + logger.info("Polaris initialized. No materializers registered.") + + +def is_initialized() -> bool: + """Check if the Polaris runtime has been initialized.""" + return _initialized + async def run(config, inputs, name, agents, on_progress: ProgressCallback | None = None): """Run an agent pipeline. @@ -14,7 +55,15 @@ async def run(config, inputs, name, agents, on_progress: ProgressCallback | None Returns: Result dict containing state and last node output + + Raises: + RuntimeError: If the runtime has not been initialized """ + if not materializer_catalog.is_frozen(): + raise RuntimeError( + "Polaris runtime not initialized. Call polaris.initialize() first." + ) + registry = Registry(config) await registry.init() registry.agents.register_agents(agents) diff --git a/packages/polaris/polaris/tests/handlers/mocks.py b/packages/polaris/polaris/tests/handlers/mocks.py index 77a18697..bf8f89d1 100644 --- a/packages/polaris/polaris/tests/handlers/mocks.py +++ b/packages/polaris/polaris/tests/handlers/mocks.py @@ -20,6 +20,7 @@ def __init__(self) -> None: self.call_api_result: dict[str, Any] = {"ok": True, "result": {"data": "test"}} self.plan_result: dict[str, Any] = {"next": "end"} self.reason_result: str = "reasoning output" + self.reason_structured_result: str = '{"route": "process"}' self.agents = MockAgentResolver() async def call_api(self, ctx: dict[str, Any], spec: dict[str, Any]) -> dict[str, Any]: @@ -34,6 +35,10 @@ async def reason(self, prompt: str, input: Any) -> str: _ = prompt, input # unused return self.reason_result + async def reason_structured(self, prompt: str, schema: dict[str, Any]) -> str: + _ = prompt, schema # unused + return self.reason_structured_result + class MockResolver: """Mock resolver for handler tests.""" diff --git a/packages/polaris/polaris/tests/handlers/test_materializer.py b/packages/polaris/polaris/tests/handlers/test_materializer.py new file mode 100644 index 00000000..e32c4a6e --- /dev/null +++ b/packages/polaris/polaris/tests/handlers/test_materializer.py @@ -0,0 +1,223 @@ +"""Tests for MaterializerHandler.""" + +import pytest + +from polaris.modules.constants import ErrorCode +from polaris.modules.handlers import MaterializerHandler +from polaris.modules.materializers.catalog import _get_catalog, register + +from .mocks import MockRegistry, MockRunner + + +@pytest.fixture(autouse=True) +def clean_catalog(): + """Clean the catalog before and after each test.""" + catalog = _get_catalog() + catalog.clear() + + # Register test materializers + @register("test.double") + def double(x: int) -> int: + return x * 2 + + @register("test.concat") + def concat(a: str, b: str) -> str: + return a + b + + @register("test.with_workspace") + def with_workspace(workspace: str, data: str) -> str: + return f"{workspace}:{data}" + + @register("test.raises") + def raises_error() -> None: + raise ValueError("Intentional error") + + catalog.freeze() + yield + catalog.clear() + + +class TestMaterializerHandler: + @pytest.mark.asyncio + async def test_execute_simple_function(self, mock_context): + """Test executing a simple materializer function.""" + handler = MaterializerHandler() + runner = MockRunner() + node = { + "type": "materializer", + "target": "test.double", + "args": {"x": 5}, + } + + result = await handler.execute(node, mock_context, MockRegistry(), runner) + + assert result["ok"] is True + assert result["result"] == 10 + + @pytest.mark.asyncio + async def test_execute_with_multiple_args(self, mock_context): + """Test executing with multiple arguments.""" + handler = MaterializerHandler() + runner = MockRunner() + node = { + "type": "materializer", + "target": "test.concat", + "args": {"a": "hello", "b": "world"}, + } + + result = await handler.execute(node, mock_context, MockRegistry(), runner) + + assert result["ok"] is True + assert result["result"] == "helloworld" + + @pytest.mark.asyncio + async def test_execute_with_workspace(self, mock_context): + """Test executing with workspace argument.""" + handler = MaterializerHandler() + runner = MockRunner() + node = { + "type": "materializer", + "target": "test.with_workspace", + "args": {"data": "test_data"}, + "workspace": "/tmp/workspace", + } + + result = await handler.execute(node, mock_context, MockRegistry(), runner) + + assert result["ok"] is True + assert result["result"] == "/tmp/workspace:test_data" + + @pytest.mark.asyncio + async def test_execute_unknown_target_fails(self, mock_context): + """Test that unknown target returns error.""" + handler = MaterializerHandler() + runner = MockRunner() + node = { + "type": "materializer", + "target": "nonexistent.materializer", + "args": {}, + } + + result = await handler.execute(node, mock_context, MockRegistry(), runner) + + assert result["ok"] is False + assert result["error"]["code"] == ErrorCode.MATERIALIZER_NOT_FOUND + + @pytest.mark.asyncio + async def test_execute_function_exception_fails(self, mock_context): + """Test that function exceptions are captured.""" + handler = MaterializerHandler() + runner = MockRunner() + node = { + "type": "materializer", + "target": "test.raises", + "args": {}, + } + + result = await handler.execute(node, mock_context, MockRegistry(), runner) + + assert result["ok"] is False + assert result["error"]["code"] == ErrorCode.MATERIALIZER_FAILED + assert "Intentional error" in result["error"]["message"] + assert "traceback" in result["error"]["details"] + + @pytest.mark.asyncio + async def test_execute_sets_result_in_context(self, mock_context): + """Test that result is set in context.""" + handler = MaterializerHandler() + runner = MockRunner() + node = { + "type": "materializer", + "target": "test.double", + "args": {"x": 7}, + } + + await handler.execute(node, mock_context, MockRegistry(), runner) + + assert mock_context["result"] == 14 + + @pytest.mark.asyncio + async def test_execute_applies_emit(self, mock_context): + """Test that emit rules are applied.""" + handler = MaterializerHandler() + runner = MockRunner() + node = { + "type": "materializer", + "target": "test.double", + "args": {"x": 3}, + "emit": {"state.doubled": "result"}, + } + + await handler.execute(node, mock_context, MockRegistry(), runner) + + assert len(runner.emitted) == 1 + + @pytest.mark.asyncio + async def test_execute_with_input_schema_valid(self, mock_context): + """Test that valid args pass input schema validation.""" + handler = MaterializerHandler() + runner = MockRunner() + node = { + "type": "materializer", + "target": "test.double", + "args": {"x": 10}, + "input_schema": { + "type": "object", + "required": ["x"], + "properties": { + "x": {"type": "integer"}, + }, + }, + } + + result = await handler.execute(node, mock_context, MockRegistry(), runner) + + assert result["ok"] is True + assert result["result"] == 20 + + @pytest.mark.asyncio + async def test_execute_with_input_schema_invalid(self, mock_context): + """Test that invalid args fail input schema validation.""" + handler = MaterializerHandler() + runner = MockRunner() + node = { + "type": "materializer", + "target": "test.double", + "args": {"x": "not_an_integer"}, + "input_schema": { + "type": "object", + "required": ["x"], + "properties": { + "x": {"type": "integer"}, + }, + }, + } + + result = await handler.execute(node, mock_context, MockRegistry(), runner) + + assert result["ok"] is False + assert result["error"]["code"] == ErrorCode.MATERIALIZER_INVALID_ARGS + + @pytest.mark.asyncio + async def test_execute_with_input_schema_missing_required(self, mock_context): + """Test that missing required args fail validation.""" + handler = MaterializerHandler() + runner = MockRunner() + node = { + "type": "materializer", + "target": "test.concat", + "args": {"a": "hello"}, # Missing 'b' + "input_schema": { + "type": "object", + "required": ["a", "b"], + "properties": { + "a": {"type": "string"}, + "b": {"type": "string"}, + }, + }, + } + + result = await handler.execute(node, mock_context, MockRegistry(), runner) + + assert result["ok"] is False + assert result["error"]["code"] == ErrorCode.MATERIALIZER_INVALID_ARGS diff --git a/packages/polaris/polaris/tests/handlers/test_planner.py b/packages/polaris/polaris/tests/handlers/test_planner.py index 8b0b6bf3..3bf5c5ff 100644 --- a/packages/polaris/polaris/tests/handlers/test_planner.py +++ b/packages/polaris/polaris/tests/handlers/test_planner.py @@ -1,39 +1,207 @@ """Tests for PlannerHandler.""" +import json + import pytest +from polaris.modules.constants import ErrorCode from polaris.modules.handlers import PlannerHandler from .mocks import MockRegistry, MockRunner -class TestPlannerHandler: +class TestPlannerHandlerRouteMode: + """Tests for route mode planner.""" + @pytest.mark.asyncio - async def test_execute_calls_plan(self, mock_context): + async def test_execute_route_mode_success(self, mock_context): + """Test successful route mode execution.""" handler = PlannerHandler() runner = MockRunner() registry = MockRegistry() + registry.reason_structured_result = '{"route": "process"}' + node = { "type": "planner", - "prompt": "Test prompt", - "tools": [], + "output_mode": "route", + "prompt": "Decide the next step", + "routes": { + "process": {"description": "Continue processing", "next": "process_node"}, + "skip": {"description": "Skip to end", "next": "done"}, + }, } result = await handler.execute(node, mock_context, registry, runner) assert result["ok"] is True - assert result["result"]["next"] == "end" + assert result["result"] == {"route": "process"} @pytest.mark.asyncio - async def test_execute_applies_emit(self, mock_context): + async def test_execute_route_mode_invalid_route(self, mock_context): + """Test route mode with invalid route value.""" handler = PlannerHandler() runner = MockRunner() registry = MockRegistry() + registry.reason_structured_result = '{"route": "invalid"}' + node = { "type": "planner", - "emit": {"state.plan": "result"}, + "output_mode": "route", + "prompt": "Decide", + "routes": { + "process": {"description": "Process", "next": "process_node"}, + }, + } + + result = await handler.execute(node, mock_context, registry, runner) + + assert result["ok"] is False + assert result["error"]["code"] == ErrorCode.PLANNER_SCHEMA_VALIDATION_FAILED + + @pytest.mark.asyncio + async def test_execute_route_mode_applies_emit(self, mock_context): + """Test that emit rules are applied in route mode.""" + handler = PlannerHandler() + runner = MockRunner() + registry = MockRegistry() + registry.reason_structured_result = '{"route": "process"}' + + node = { + "type": "planner", + "output_mode": "route", + "prompt": "Decide", + "routes": { + "process": {"description": "Process", "next": "process_node"}, + }, + "emit": {"state.decision": {"$ref": "result.route"}}, } await handler.execute(node, mock_context, registry, runner) assert len(runner.emitted) == 1 + + +class TestPlannerHandlerJsonMode: + """Tests for JSON mode planner.""" + + @pytest.mark.asyncio + async def test_execute_json_mode_success(self, mock_context): + """Test successful JSON mode execution.""" + handler = PlannerHandler() + runner = MockRunner() + registry = MockRegistry() + registry.reason_structured_result = '{"threshold": 0.5, "columns": ["a", "b"]}' + + node = { + "type": "planner", + "output_mode": "json", + "prompt": "Extract parameters", + "output_schema": { + "type": "object", + "required": ["threshold", "columns"], + "properties": { + "threshold": {"type": "number"}, + "columns": {"type": "array", "items": {"type": "string"}}, + }, + }, + "next": "process_node", + } + + result = await handler.execute(node, mock_context, registry, runner) + + assert result["ok"] is True + assert result["result"]["threshold"] == 0.5 + assert result["result"]["columns"] == ["a", "b"] + + @pytest.mark.asyncio + async def test_execute_json_mode_schema_violation(self, mock_context): + """Test JSON mode with schema violation.""" + handler = PlannerHandler() + runner = MockRunner() + registry = MockRegistry() + registry.reason_structured_result = '{"threshold": "not a number"}' + + node = { + "type": "planner", + "output_mode": "json", + "prompt": "Extract", + "output_schema": { + "type": "object", + "required": ["threshold"], + "properties": {"threshold": {"type": "number"}}, + }, + "next": "process_node", + } + + result = await handler.execute(node, mock_context, registry, runner) + + assert result["ok"] is False + assert result["error"]["code"] == ErrorCode.PLANNER_SCHEMA_VALIDATION_FAILED + + @pytest.mark.asyncio + async def test_execute_json_mode_invalid_json(self, mock_context): + """Test JSON mode with invalid JSON response.""" + handler = PlannerHandler() + runner = MockRunner() + registry = MockRegistry() + registry.reason_structured_result = "not valid json" + + node = { + "type": "planner", + "output_mode": "json", + "prompt": "Extract", + "output_schema": {"type": "object"}, + "next": "process_node", + } + + result = await handler.execute(node, mock_context, registry, runner) + + assert result["ok"] is False + assert result["error"]["code"] == ErrorCode.PLANNER_INVALID_JSON + + @pytest.mark.asyncio + async def test_execute_json_mode_sets_context(self, mock_context): + """Test that result is set in context.""" + handler = PlannerHandler() + runner = MockRunner() + registry = MockRegistry() + registry.reason_structured_result = '{"value": 42}' + + node = { + "type": "planner", + "output_mode": "json", + "prompt": "Extract", + "output_schema": { + "type": "object", + "properties": {"value": {"type": "number"}}, + }, + "next": "process_node", + } + + await handler.execute(node, mock_context, registry, runner) + + assert mock_context["result"] == {"value": 42} + + @pytest.mark.asyncio + async def test_execute_with_input_context(self, mock_context): + """Test that input is resolved and passed to prompt.""" + handler = PlannerHandler() + runner = MockRunner() + registry = MockRegistry() + registry.reason_structured_result = '{"route": "process"}' + + mock_context["state"]["data"] = {"key": "value"} + + node = { + "type": "planner", + "output_mode": "route", + "prompt": "Analyze data", + "input": {"data": {"$ref": "state.data"}}, + "routes": { + "process": {"description": "Process", "next": "process_node"}, + }, + } + + result = await handler.execute(node, mock_context, registry, runner) + + assert result["ok"] is True diff --git a/packages/polaris/polaris/tests/handlers/test_planner_shim.py b/packages/polaris/polaris/tests/handlers/test_planner_shim.py new file mode 100644 index 00000000..d5bd3437 --- /dev/null +++ b/packages/polaris/polaris/tests/handlers/test_planner_shim.py @@ -0,0 +1,170 @@ +"""Tests for PlannerOutputShim.""" + +import pytest + +from polaris.modules.constants import ErrorCode +from polaris.modules.handlers.planner import PlannerOutputShim, build_route_schema + + +class TestPlannerOutputShim: + """Tests for the planner output shim.""" + + def setup_method(self): + self.shim = PlannerOutputShim() + + def test_validate_valid_json(self): + """Test validation of valid JSON against schema.""" + schema = { + "type": "object", + "required": ["name"], + "properties": {"name": {"type": "string"}}, + } + raw = '{"name": "test"}' + + result = self.shim.validate(raw, schema) + + assert result["ok"] is True + assert result["result"] == {"name": "test"} + + def test_validate_invalid_json_syntax(self): + """Test handling of invalid JSON syntax.""" + schema = {"type": "object"} + raw = "not valid json" + + result = self.shim.validate(raw, schema) + + assert result["ok"] is False + assert result["error"]["code"] == ErrorCode.PLANNER_INVALID_JSON + assert "raw_truncated" in result["error"]["details"] + + def test_validate_schema_violation(self): + """Test handling of schema validation failure.""" + schema = { + "type": "object", + "required": ["name"], + "properties": {"name": {"type": "string"}}, + } + raw = '{"name": 123}' # Wrong type + + result = self.shim.validate(raw, schema) + + assert result["ok"] is False + assert result["error"]["code"] == ErrorCode.PLANNER_SCHEMA_VALIDATION_FAILED + assert "path" in result["error"]["details"] + + def test_validate_missing_required_field(self): + """Test handling of missing required field.""" + schema = { + "type": "object", + "required": ["name", "value"], + "properties": { + "name": {"type": "string"}, + "value": {"type": "number"}, + }, + } + raw = '{"name": "test"}' # Missing 'value' + + result = self.shim.validate(raw, schema) + + assert result["ok"] is False + assert result["error"]["code"] == ErrorCode.PLANNER_SCHEMA_VALIDATION_FAILED + + def test_validate_route_enum(self): + """Test validation of route enum value.""" + schema = build_route_schema({ + "process": {"description": "Continue", "next": "process_node"}, + "skip": {"description": "Skip", "next": "done"}, + }) + raw = '{"route": "process"}' + + result = self.shim.validate(raw, schema) + + assert result["ok"] is True + assert result["result"] == {"route": "process"} + + def test_validate_invalid_route_enum(self): + """Test rejection of invalid route enum value.""" + schema = build_route_schema({ + "process": {"description": "Continue", "next": "process_node"}, + "skip": {"description": "Skip", "next": "done"}, + }) + raw = '{"route": "invalid"}' + + result = self.shim.validate(raw, schema) + + assert result["ok"] is False + assert result["error"]["code"] == ErrorCode.PLANNER_SCHEMA_VALIDATION_FAILED + + def test_validate_complex_schema(self): + """Test validation of complex nested schema.""" + schema = { + "type": "object", + "required": ["config"], + "properties": { + "config": { + "type": "object", + "required": ["threshold"], + "properties": { + "threshold": {"type": "number", "minimum": 0, "maximum": 1}, + "tags": {"type": "array", "items": {"type": "string"}}, + }, + }, + }, + } + raw = '{"config": {"threshold": 0.5, "tags": ["a", "b"]}}' + + result = self.shim.validate(raw, schema) + + assert result["ok"] is True + assert result["result"]["config"]["threshold"] == 0.5 + assert result["result"]["config"]["tags"] == ["a", "b"] + + def test_validate_additional_properties_rejected(self): + """Test that additional properties are rejected when schema forbids them.""" + schema = { + "type": "object", + "required": ["route"], + "properties": {"route": {"enum": ["a", "b"]}}, + "additionalProperties": False, + } + raw = '{"route": "a", "extra": "field"}' + + result = self.shim.validate(raw, schema) + + assert result["ok"] is False + assert result["error"]["code"] == ErrorCode.PLANNER_SCHEMA_VALIDATION_FAILED + + +class TestBuildRouteSchema: + """Tests for build_route_schema helper.""" + + def test_build_route_schema_single_route(self): + """Test schema generation for single route.""" + routes = {"process": {"description": "Process data", "next": "process_node"}} + + schema = build_route_schema(routes) + + assert schema["type"] == "object" + assert schema["required"] == ["route"] + assert schema["properties"]["route"]["enum"] == ["process"] + assert schema["additionalProperties"] is False + + def test_build_route_schema_multiple_routes(self): + """Test schema generation for multiple routes.""" + routes = { + "process": {"description": "Process", "next": "process_node"}, + "skip": {"description": "Skip", "next": "skip_node"}, + "retry": {"description": "Retry", "next": "retry_node"}, + } + + schema = build_route_schema(routes) + + assert set(schema["properties"]["route"]["enum"]) == {"process", "skip", "retry"} + + def test_build_route_schema_empty_routes(self): + """Test schema generation for empty routes.""" + routes = {} + + schema = build_route_schema(routes) + + assert schema["properties"]["route"]["enum"] == [] diff --git a/packages/polaris/polaris/tests/test_constants.py b/packages/polaris/polaris/tests/test_constants.py index 724f7c1b..13b760c1 100644 --- a/packages/polaris/polaris/tests/test_constants.py +++ b/packages/polaris/polaris/tests/test_constants.py @@ -11,7 +11,17 @@ class TestNodeType: def test_all_node_types_defined(self): - expected = ["compute", "control", "executor", "planner", "reasoning", "terminal"] + expected = [ + "compute", + "control", + "executor", + "loop", + "materializer", + "planner", + "reasoning", + "terminal", + "traverse", + ] for node_type in expected: assert node_type in [n.value for n in NodeType] @@ -19,9 +29,12 @@ def test_node_type_values(self): assert NodeType.COMPUTE == "compute" assert NodeType.CONTROL == "control" assert NodeType.EXECUTOR == "executor" + assert NodeType.LOOP == "loop" + assert NodeType.MATERIALIZER == "materializer" assert NodeType.PLANNER == "planner" assert NodeType.REASONING == "reasoning" assert NodeType.TERMINAL == "terminal" + assert NodeType.TRAVERSE == "traverse" def test_node_type_is_string_enum(self): assert isinstance(NodeType.COMPUTE, str) @@ -54,6 +67,15 @@ def test_all_error_codes_defined(self): "subagent_failed", "unknown_executor_op", "unknown_node_type", + "loop_invalid_over", + "loop_iteration_failed", + "materializer_failed", + "materializer_not_found", + "materializer_invalid_args", + "planner_invalid_json", + "planner_schema_validation_failed", + "traverse_invalid_config", + "traverse_fetch_failed", ] for code in expected: assert code in [e.value for e in ErrorCode] @@ -65,6 +87,11 @@ def test_error_code_values(self): assert ErrorCode.SUBAGENT_FAILED == "subagent_failed" assert ErrorCode.UNKNOWN_EXECUTOR_OP == "unknown_executor_op" assert ErrorCode.UNKNOWN_NODE_TYPE == "unknown_node_type" + assert ErrorCode.MATERIALIZER_FAILED == "materializer_failed" + assert ErrorCode.MATERIALIZER_NOT_FOUND == "materializer_not_found" + assert ErrorCode.MATERIALIZER_INVALID_ARGS == "materializer_invalid_args" + assert ErrorCode.PLANNER_INVALID_JSON == "planner_invalid_json" + assert ErrorCode.PLANNER_SCHEMA_VALIDATION_FAILED == "planner_schema_validation_failed" class TestMaxNodes: diff --git a/packages/polaris/polaris/tests/test_materializer_catalog.py b/packages/polaris/polaris/tests/test_materializer_catalog.py new file mode 100644 index 00000000..0de0841f --- /dev/null +++ b/packages/polaris/polaris/tests/test_materializer_catalog.py @@ -0,0 +1,176 @@ +"""Tests for the materializer catalog.""" + +import pytest + +from polaris.modules.materializers.catalog import MaterializerCatalog + + +class TestMaterializerCatalog: + """Tests for MaterializerCatalog.""" + + def test_register_and_get(self): + """Test registering and retrieving a materializer.""" + catalog = MaterializerCatalog() + + def my_fn(x: int) -> int: + return x * 2 + + catalog.register("test.my_fn", my_fn) + retrieved = catalog.get("test.my_fn") + + assert retrieved is my_fn + assert retrieved(5) == 10 + + def test_register_duplicate_raises(self): + """Test that registering a duplicate name raises.""" + catalog = MaterializerCatalog() + + def fn1(): + pass + + def fn2(): + pass + + catalog.register("test.fn", fn1) + + with pytest.raises(ValueError, match="already registered"): + catalog.register("test.fn", fn2) + + def test_get_unknown_raises(self): + """Test that getting an unknown materializer raises.""" + catalog = MaterializerCatalog() + + with pytest.raises(KeyError, match="Unknown materializer"): + catalog.get("nonexistent") + + def test_freeze_prevents_registration(self): + """Test that freeze prevents further registrations.""" + catalog = MaterializerCatalog() + catalog.freeze() + + def fn(): + pass + + with pytest.raises(RuntimeError, match="catalog is frozen"): + catalog.register("test.fn", fn) + + def test_is_frozen(self): + """Test is_frozen returns correct state.""" + catalog = MaterializerCatalog() + + assert catalog.is_frozen() is False + catalog.freeze() + assert catalog.is_frozen() is True + + def test_list_all(self): + """Test list_all returns sorted names.""" + catalog = MaterializerCatalog() + + catalog.register("z.last", lambda: None) + catalog.register("a.first", lambda: None) + catalog.register("m.middle", lambda: None) + + names = catalog.list_all() + + assert names == ["a.first", "m.middle", "z.last"] + + def test_clear_resets_catalog(self): + """Test clear resets the catalog for testing.""" + catalog = MaterializerCatalog() + + catalog.register("test.fn", lambda: None) + catalog.freeze() + + catalog.clear() + + assert catalog.is_frozen() is False + assert catalog.list_all() == [] + + def test_get_after_freeze_works(self): + """Test that get still works after freeze.""" + catalog = MaterializerCatalog() + + def fn(): + return "result" + + catalog.register("test.fn", fn) + catalog.freeze() + + retrieved = catalog.get("test.fn") + assert retrieved() == "result" + + +class TestModuleLevelFunctions: + """Tests for module-level convenience functions.""" + + def test_register_decorator(self): + """Test the register decorator.""" + from polaris.modules.materializers.catalog import _get_catalog, register + + catalog = _get_catalog() + catalog.clear() + + @register("test.decorated_fn") + def decorated_fn(x: int) -> int: + return x + 1 + + assert "test.decorated_fn" in catalog.list_all() + assert catalog.get("test.decorated_fn")(10) == 11 + + catalog.clear() + + def test_get_function(self): + """Test the module-level get function.""" + from polaris.modules.materializers.catalog import _get_catalog, get, register + + catalog = _get_catalog() + catalog.clear() + + @register("test.get_test") + def fn(): + return "hello" + + retrieved = get("test.get_test") + assert retrieved() == "hello" + + catalog.clear() + + def test_freeze_and_is_frozen(self): + """Test module-level freeze and is_frozen.""" + from polaris.modules.materializers.catalog import ( + _get_catalog, + freeze, + is_frozen, + ) + + catalog = _get_catalog() + catalog.clear() + + assert is_frozen() is False + freeze() + assert is_frozen() is True + + catalog.clear() + + def test_list_all_function(self): + """Test module-level list_all.""" + from polaris.modules.materializers.catalog import ( + _get_catalog, + list_all, + register, + ) + + catalog = _get_catalog() + catalog.clear() + + @register("b.second") + def fn1(): + pass + + @register("a.first") + def fn2(): + pass + + assert list_all() == ["a.first", "b.second"] + + catalog.clear() diff --git a/packages/polaris/polaris/tests/test_runtime.py b/packages/polaris/polaris/tests/test_runtime.py index 6d420e23..1e3f0d34 100644 --- a/packages/polaris/polaris/tests/test_runtime.py +++ b/packages/polaris/polaris/tests/test_runtime.py @@ -1,32 +1,23 @@ import pytest +from polaris.modules.materializers.catalog import _get_catalog from polaris.runtime import run +@pytest.fixture(autouse=True) +def initialize_runtime(): + """Initialize the materializer catalog for tests.""" + catalog = _get_catalog() + catalog.clear() + catalog.freeze() + yield + catalog.clear() + + @pytest.mark.asyncio async def test_run_minimal_agent(monkeypatch): - async def fake_completions_post(payload): - return { - "choices": [ - { - "message": { - "tool_calls": [ - { - "function": { - "name": "route", - "arguments": '{"next": "end"}', - } - } - ] - } - } - ] - } - - monkeypatch.setattr( - "polaris.modules.registry.completions_post", - fake_completions_post, - ) + async def fake_reason_structured(self, prompt, schema): + return '{"route": "done"}' async def fake_load_providers(config): return [] @@ -40,8 +31,17 @@ async def fake_load_providers(config): "nodes": { "start": { "type": "planner", - "next": None, - } + "output_mode": "route", + "prompt": "Decide next step", + "routes": { + "done": {"description": "End workflow", "next": "end"}, + }, + "emit": {"state.decision": {"$ref": "result.route"}}, + }, + "end": { + "type": "terminal", + "output": {"decision": {"$ref": "state.decision"}}, + }, }, "start": "start", } @@ -54,7 +54,20 @@ async def fake_load_providers(config): } agents = {"test_agent": agent} - result = await run(config, inputs, "test_agent", agents) + # Patch after agent is defined since Registry is created during run() + original_run = run + + async def patched_run(config, inputs, agent_id, agents, workspace=None): + from polaris.modules.registry import Registry + + original_reason_structured = Registry.reason_structured + Registry.reason_structured = fake_reason_structured + try: + return await original_run(config, inputs, agent_id, agents, workspace) + finally: + Registry.reason_structured = original_reason_structured + + result = await patched_run(config, inputs, "test_agent", agents) - assert result["last"]["result"]["next"] == "end" assert result["last"]["ok"] is True + assert result["last"]["result"]["decision"] == "done" diff --git a/packages/polaris/polaris/tests/test_schema.py b/packages/polaris/polaris/tests/test_schema.py index 6d639941..2e9bd44c 100644 --- a/packages/polaris/polaris/tests/test_schema.py +++ b/packages/polaris/polaris/tests/test_schema.py @@ -10,7 +10,9 @@ ExecutorNode, InputSpec, LoopNode, + MaterializerNode, PlannerNode, + PlannerRouteSpec, ReasoningNode, StateSpec, TerminalNode, @@ -241,21 +243,229 @@ def test_compute_with_emit(self): class TestPlannerNode: - def test_minimal_planner(self): - node = PlannerNode(type="planner") + def test_route_planner_minimal(self): + """Test minimal route planner configuration.""" + node = PlannerNode( + type="planner", + output_mode="route", + routes={ + "process": {"description": "Continue processing", "next": "process_node"}, + }, + ) + assert node.type == "planner" + assert node.output_mode == "route" + assert "process" in node.routes + + def test_route_planner_multiple_routes(self): + """Test route planner with multiple routes.""" + node = PlannerNode( + type="planner", + output_mode="route", + prompt="Decide the next step", + routes={ + "process": {"description": "Continue", "next": "process_node"}, + "skip": {"description": "Skip", "next": "skip_node"}, + "retry": {"description": "Retry", "next": "retry_node"}, + }, + ) + assert len(node.routes) == 3 + assert node.routes["process"].next == "process_node" + + def test_route_planner_rejects_static_next(self): + """Test that route planners cannot have static next.""" + with pytest.raises(ValidationError, match="cannot have static 'next'"): + PlannerNode( + type="planner", + output_mode="route", + routes={"process": {"description": "Process", "next": "node"}}, + next="some_node", + ) + + def test_route_planner_rejects_output_schema(self): + """Test that route planners cannot have output_schema.""" + with pytest.raises(ValidationError, match="cannot have 'output_schema'"): + PlannerNode( + type="planner", + output_mode="route", + routes={"process": {"description": "Process", "next": "node"}}, + output_schema={"type": "object"}, + ) + + def test_route_planner_requires_routes(self): + """Test that route planners require routes.""" + with pytest.raises(ValidationError, match="require 'routes'"): + PlannerNode( + type="planner", + output_mode="route", + ) + + def test_route_planner_requires_nonempty_routes(self): + """Test that route planners require at least one route.""" + with pytest.raises(ValidationError, match="at least one route"): + PlannerNode( + type="planner", + output_mode="route", + routes={}, + ) + + def test_json_planner_minimal(self): + """Test minimal JSON planner configuration.""" + node = PlannerNode( + type="planner", + output_mode="json", + output_schema={"type": "object", "properties": {"value": {"type": "number"}}}, + next="process_node", + ) assert node.type == "planner" - assert node.prompt == "" - assert node.tools == [] + assert node.output_mode == "json" + assert node.next == "process_node" + + def test_json_planner_requires_output_schema(self): + """Test that JSON planners require output_schema.""" + with pytest.raises(ValidationError, match="require 'output_schema'"): + PlannerNode( + type="planner", + output_mode="json", + next="process_node", + ) - def test_planner_with_options(self): + def test_json_planner_requires_static_next(self): + """Test that JSON planners require static next.""" + with pytest.raises(ValidationError, match="require static 'next'"): + PlannerNode( + type="planner", + output_mode="json", + output_schema={"type": "object"}, + ) + + def test_json_planner_rejects_routes(self): + """Test that JSON planners cannot have routes.""" + with pytest.raises(ValidationError, match="cannot have 'routes'"): + PlannerNode( + type="planner", + output_mode="json", + output_schema={"type": "object"}, + next="process_node", + routes={"process": {"description": "Process", "next": "node"}}, + ) + + def test_planner_with_input(self): + """Test planner with input configuration.""" node = PlannerNode( type="planner", - prompt="Plan the analysis", - tools=["search", "fetch"], - output_schema={"type": "object"}, + output_mode="route", + prompt="Analyze this data", + input={"data": {"$ref": "state.data"}}, + routes={"process": {"description": "Process", "next": "node"}}, ) - assert node.prompt == "Plan the analysis" - assert node.tools == ["search", "fetch"] + assert node.input is not None + assert "data" in node.input + + +class TestMaterializerNode: + def test_minimal_materializer(self): + node = MaterializerNode( + type="materializer", + target="my_module.my_function", + ) + assert node.type == "materializer" + assert node.target == "my_module.my_function" + assert node.args == {} + assert node.workspace is None + assert node.input_schema is None + + def test_materializer_with_args(self): + from polaris.modules.schema import RefExpr + + node = MaterializerNode( + type="materializer", + target="my_module.transform", + args={ + "data": {"$ref": "state.data"}, + "format": "json", + }, + ) + assert node.args["format"] == "json" + assert isinstance(node.args["data"], RefExpr) + assert node.args["data"].ref == "state.data" + + def test_materializer_with_workspace(self): + node = MaterializerNode( + type="materializer", + target="my_module.generate_file", + args={"template": "report"}, + workspace="/tmp/workspace", + ) + assert node.workspace == "/tmp/workspace" + + def test_materializer_with_dynamic_workspace(self): + from polaris.modules.schema import RefExpr + + node = MaterializerNode( + type="materializer", + target="my_module.generate_file", + workspace={"$ref": "config.workspace_path"}, + ) + assert isinstance(node.workspace, RefExpr) + assert node.workspace.ref == "config.workspace_path" + + def test_materializer_with_input_schema(self): + node = MaterializerNode( + type="materializer", + target="my_module.validate_data", + args={"data": {"$ref": "state.data"}}, + input_schema={ + "type": "object", + "required": ["data"], + "properties": { + "data": {"type": "array"}, + }, + }, + ) + assert node.input_schema["required"] == ["data"] + + def test_materializer_with_emit_and_next(self): + node = MaterializerNode( + type="materializer", + target="my_module.transform", + args={"input": "value"}, + emit={"state.output": {"$ref": "result"}}, + next="done", + ) + assert "state.output" in node.emit + assert node.next == "done" + + def test_materializer_rejects_empty_target(self): + with pytest.raises(ValidationError): + MaterializerNode( + type="materializer", + target="", + ) + + def test_materializer_in_agent_definition(self): + """Test that materializer nodes work in a full agent definition.""" + agent = validate_agent({ + "version": 1, + "id": "test_agent", + "start": "materialize", + "nodes": { + "materialize": { + "type": "materializer", + "target": "my_module.generate_report", + "args": { + "data": {"$ref": "state.data"}, + }, + "emit": {"state.report": {"$ref": "result"}}, + "next": "done", + }, + "done": { + "type": "terminal", + "output": {"report": {"$ref": "state.report"}}, + }, + }, + }) + assert isinstance(agent.nodes["materialize"], MaterializerNode) + assert agent.nodes["materialize"].target == "my_module.generate_report" class TestAgentDefinition: diff --git a/packages/polaris/polaris_dataset_report/polaris_dataset_report/__init__.py b/packages/polaris/polaris_dataset_report/polaris_dataset_report/__init__.py index 4b6188f1..3b927705 100644 --- a/packages/polaris/polaris_dataset_report/polaris_dataset_report/__init__.py +++ b/packages/polaris/polaris_dataset_report/polaris_dataset_report/__init__.py @@ -9,11 +9,10 @@ import yaml +import polaris from polaris import run as polaris_run from polaris.modules.runner import ProgressCallback -from .postprocess import postprocess - # Load agent definition from bundled YAML _AGENT_PATH = Path(__file__).parent / "agent.yml" _AGENT_NAME = "dataset_report" @@ -40,13 +39,10 @@ async def run( Returns: Result dict containing report, workflow analysis, and mermaid diagram """ + # Initialize framework before loading agent definitions + polaris.initialize() + agent = _load_agent() agents = {_AGENT_NAME: agent} - result = await polaris_run(config, inputs, _AGENT_NAME, agents, on_progress) - - # Apply postprocessing to add mermaid diagram - if result.get("last"): - result["last"] = postprocess(result["last"], result.get("state", {})) - - return result + return await polaris_run(config, inputs, _AGENT_NAME, agents, on_progress) diff --git a/packages/polaris/polaris_dataset_report/polaris_dataset_report/agent.yml b/packages/polaris/polaris_dataset_report/polaris_dataset_report/agent.yml index 83d1f08a..65f8fd61 100644 --- a/packages/polaris/polaris_dataset_report/polaris_dataset_report/agent.yml +++ b/packages/polaris/polaris_dataset_report/polaris_dataset_report/agent.yml @@ -28,6 +28,8 @@ state: type: string report: type: string + mermaid_diagram: + type: string nodes: # Step 1: Fetch source dataset to seed the traversal @@ -191,6 +193,22 @@ nodes: emit: state.report: $ref: result + next: generate_diagram + + # Step 7: Generate mermaid diagram from lineage + generate_diagram: + type: materializer + target: dataset_report.generate_mermaid + args: + dataset_details: + $ref: state.dataset_details + job_details: + $ref: state.job_details + source_dataset_id: + $ref: inputs.dataset_id + emit: + state.mermaid_diagram: + $ref: result next: done # Terminal node - return the generated report @@ -220,3 +238,5 @@ nodes: $ref: state.workflow_analysis report: $ref: state.report + mermaid_diagram: + $ref: state.mermaid_diagram diff --git a/packages/polaris/polaris_dataset_report/polaris_dataset_report/cli.py b/packages/polaris/polaris_dataset_report/polaris_dataset_report/cli.py index 7b00f980..198aeaf8 100644 --- a/packages/polaris/polaris_dataset_report/polaris_dataset_report/cli.py +++ b/packages/polaris/polaris_dataset_report/polaris_dataset_report/cli.py @@ -6,8 +6,6 @@ import logging import sys -from .postprocess import postprocess - def parse_inputs(input_args: list[str]) -> dict: """Parse key=value input arguments into a dictionary.""" @@ -38,10 +36,14 @@ def setup_logging(verbose: bool = False) -> None: async def run_agent(inputs: dict, verbose: bool = False) -> dict: """Run the dataset_report agent.""" + import polaris from polaris.config import load_config from . import run + # Initialize framework before loading config or agent definitions + polaris.initialize() + config = load_config().to_dict() def on_progress(event): diff --git a/packages/polaris/polaris_dataset_report/polaris_dataset_report/materializers.py b/packages/polaris/polaris_dataset_report/polaris_dataset_report/materializers.py new file mode 100644 index 00000000..1d47b022 --- /dev/null +++ b/packages/polaris/polaris_dataset_report/polaris_dataset_report/materializers.py @@ -0,0 +1,45 @@ +"""Materializer functions for the dataset report agent. + +These functions are registered with the Polaris materializer catalog +via the entry point mechanism at framework initialization. +""" + +from typing import Any + +from polaris.modules.materializers import register + +from .postprocess import generate_mermaid as _generate_mermaid + + +@register("dataset_report.generate_mermaid") +def generate_mermaid( + dataset_details: list[dict[str, Any]], + job_details: list[dict[str, Any]], + source_dataset_id: str | None = None, +) -> str: + """Generate a Mermaid flowchart from dataset and job details. + + This materializer wraps the existing generate_mermaid function from + postprocess.py, making it available for use in agent YAML definitions. + + Args: + dataset_details: List of dataset detail dicts with id, uuid, name, file_ext, creating_job + job_details: List of job detail dicts with id, tool_id, inputs, outputs, create_time + source_dataset_id: Optional ID of the source dataset to highlight + + Returns: + Mermaid diagram string + """ + return _generate_mermaid(dataset_details, job_details, source_dataset_id) + + +def register_all() -> None: + """Register all materializers for this package. + + This function is called by the Polaris framework via the entry point + mechanism during initialization. The actual registration happens via + the @register decorator when the module is imported. + """ + # All materializers are registered via decorators when this module is imported. + # This function exists to satisfy the entry point interface. + pass diff --git a/packages/polaris/polaris_dataset_report/pyproject.toml b/packages/polaris/polaris_dataset_report/pyproject.toml index 731c0469..fcafd230 100644 --- a/packages/polaris/polaris_dataset_report/pyproject.toml +++ b/packages/polaris/polaris_dataset_report/pyproject.toml @@ -15,6 +15,9 @@ dependencies = [ [project.scripts] dataset-report = "polaris_dataset_report.cli:main" +[project.entry-points."polaris.materializers"] +dataset_report = "polaris_dataset_report.materializers:register_all" + [tool.setuptools.packages.find] where = ["."]