diff --git a/.gitignore b/.gitignore
index 7f97c7d6a8..70a6fa6e4c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -203,6 +203,7 @@ cache
 /workspace/
 openapi.json
 .client/
+/.agent_server_llm_switch_demo/
 # Local workspace files
 .beads/*.db
 .worktrees/
diff --git a/docs/llm_switching.md b/docs/llm_switching.md
new file mode 100644
index 0000000000..ba549a79aa
--- /dev/null
+++ b/docs/llm_switching.md
@@ -0,0 +1,56 @@
+# Runtime LLM Switching (SDK + agent-server)
+
+This repo supports **runtime LLM switching** without any “agent immutability” or
+resume-time diff enforcement. The guiding principles are:
+
+- An `Agent` is a *composition* of (effectively) immutable components (`LLM`,
+  `AgentContext`, etc.).
+- The composition is **switchable**: at runtime the conversation can replace
+  components (currently the agent’s primary `LLM`) and persist the change.
+
+## Persistence model (single rule)
+
+Conversation snapshots (`base_state.json`) persist the agent’s LLM as:
+
+- `{"profile_id": "<profile_id>"}` when `LLM.profile_id` is present
+- a full inline LLM payload when `LLM.profile_id` is absent
+
+Snapshots are written with `context={"expose_secrets": True}` so inline LLMs can
+restore without “reconciliation” against a runtime agent.
+
+## SDK API
+
+Local conversations support two LLM update paths:
+
+- `LocalConversation.switch_llm(profile_id: str)`:
+  loads `<profile_id>.json` from the registry’s profile dir and swaps the active
+  LLM for the agent’s `usage_id`.
+- `LocalConversation.set_llm(llm: LLM)`:
+  replaces the active LLM instance for the agent’s `usage_id` (useful for remote
+  clients that can’t rely on server-side profile files).
+
+Both are persisted immediately via the conversation’s base state snapshot.
+
+## agent-server API
+
+For a running (or paused/idle) conversation:
+
+- `POST /api/conversations/{conversation_id}/llm`
+  - `{"profile_id": "<profile_id>"}`: switch via server-side profile loading
+  - `{"llm": {...}}`: set an inline LLM payload (client-supplied config)
+
+There is also a convenience alias:
+
+- `POST /api/conversations/{conversation_id}/llm/switch`
+  - `{"profile_id": "<profile_id>"}`
+
+## Remote clients (VS Code extension)
+
+VS Code LLM Profiles are **local-only**. The recommended remote flow is:
+
+1. Resolve `profileId` locally to an LLM configuration.
+2. Start the conversation with an expanded `agent.llm` payload (no `profile_id`).
+3. On profile changes, call `POST /api/conversations/{id}/llm` with
+   `{"llm": <expanded LLM config>}` so the server persists the new LLM.
+4. On restore, the server’s persisted LLM is the source of truth; the client
+   can re-apply its selected profile before triggering a new run if desired.
diff --git a/examples/01_standalone_sdk/26_runtime_llm_switch.py b/examples/01_standalone_sdk/26_runtime_llm_switch.py
index 4cf5ed4de1..372b0bda0f 100644
--- a/examples/01_standalone_sdk/26_runtime_llm_switch.py
+++ b/examples/01_standalone_sdk/26_runtime_llm_switch.py
@@ -22,10 +22,7 @@
 api_key = os.getenv("LLM_API_KEY")
 assert api_key is not None, "LLM_API_KEY environment variable is not set."
 
-# 2. Disable inline conversations so profile references are stored instead
-os.environ.setdefault("OPENHANDS_INLINE_CONVERSATIONS", "false")
-
-# 3. Profiles live under ~/.openhands/llm-profiles by default. We create two
+# 2. Profiles live under ~/.openhands/llm-profiles by default. We create two
 # variants that share the same usage_id so they can be swapped at runtime.
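+# (Profile files are plain JSON dumps of the LLM config; the directory can be
+# overridden via the OPENHANDS_LLM_PROFILES_DIR environment variable.)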
registry = LLMRegistry() usage_id = "support-agent" @@ -123,29 +120,3 @@ reloaded.run() print("Reloaded run finished with profile:", reloaded.state.agent.llm.profile_id) - -# --------------------------------------------------------------------------- -# Part 2: Inline persistence rejects runtime switching -# --------------------------------------------------------------------------- -# When OPENHANDS_INLINE_CONVERSATIONS is true the conversation persists full -# LLM payloads instead of profile references. Switching profiles would break -# the diff reconciliation step, so the SDK deliberately rejects it with a -# RuntimeError. We demonstrate that behaviour below. -os.environ["OPENHANDS_INLINE_CONVERSATIONS"] = "true" - -inline_persistence_dir = Path("./.conversations_switch_demo_inline").resolve() -inline_agent = Agent(llm=registry.load_profile(base_profile_id), tools=[]) -inline_conversation = Conversation( - agent=inline_agent, - workspace=str(workspace_dir), - persistence_dir=str(inline_persistence_dir), - conversation_id=uuid.uuid4(), - visualizer=None, -) - -try: - inline_conversation.switch_llm(alt_profile_id) -except RuntimeError as exc: - print("Inline mode switch attempt rejected as expected:", exc) -else: - raise AssertionError("Inline mode should have rejected the LLM switch") diff --git a/examples/02_remote_agent_server/07_llm_switch_and_restore.py b/examples/02_remote_agent_server/07_llm_switch_and_restore.py new file mode 100644 index 0000000000..1fa3015067 --- /dev/null +++ b/examples/02_remote_agent_server/07_llm_switch_and_restore.py @@ -0,0 +1,183 @@ +"""Demonstrate agent-server LLM switching + persistence across restart. + +This script: +1) Starts a local Python agent-server with a dedicated conversations directory. +2) Creates a conversation (without running it). +3) Switches the conversation's active LLM via `POST /api/conversations/{id}/llm`. +4) Restarts the agent-server and verifies the switched LLM persists on restore. + +The switch uses an inline LLM payload, which is the recommended path for remote +clients whose "profiles" are local-only (e.g. the VS Code extension). 
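+
+Note: the model names below are placeholders and the conversation is never
+run, so no real LLM calls are made; the demo purely exercises the
+switch-and-restore path.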
+""" + +from __future__ import annotations + +import json +import os +import socket +import subprocess +import sys +import time +from pathlib import Path + +import httpx + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + sock.listen(1) + return int(sock.getsockname()[1]) + + +def _wait_for_health(base_url: str, timeout_seconds: float = 30.0) -> None: + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + response = httpx.get(f"{base_url}/health", timeout=1.0) + if response.status_code == 200: + return + except Exception: + pass + time.sleep(0.25) + raise RuntimeError( + f"Timed out waiting for agent-server health at {base_url}/health" + ) + + +def _start_agent_server( + *, conversations_path: Path +) -> tuple[subprocess.Popen[str], str]: + port = _find_free_port() + base_url = f"http://127.0.0.1:{port}" + + env = { + **os.environ, + "PYTHONUNBUFFERED": "1", + "OH_ENABLE_VSCODE": "0", + "OH_ENABLE_VNC": "0", + "OH_PRELOAD_TOOLS": "0", + "SESSION_API_KEY": "", + "OH_CONVERSATIONS_PATH": str(conversations_path), + } + + proc = subprocess.Popen( + [ + sys.executable, + "-m", + "openhands.agent_server", + "--host", + "127.0.0.1", + "--port", + str(port), + ], + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + try: + _wait_for_health(base_url) + except Exception: + try: + output = (proc.stdout.read() if proc.stdout else "") or "" + except Exception: + output = "" + proc.terminate() + raise RuntimeError(f"agent-server failed to start.\n\n{output}") from None + + return proc, base_url + + +def _stop_agent_server(proc: subprocess.Popen[str]) -> None: + proc.terminate() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + proc.wait(timeout=5) + + +def main() -> None: + root = Path(".agent_server_llm_switch_demo").resolve() + conversations_path = root / "conversations" + workspace_path = root / "workspace" + conversations_path.mkdir(parents=True, exist_ok=True) + workspace_path.mkdir(parents=True, exist_ok=True) + + proc_1, base_1 = _start_agent_server(conversations_path=conversations_path) + conversation_id: str + try: + print("agent-server #1:", base_1) + + create = httpx.post( + f"{base_1}/api/conversations", + json={ + "agent": { + "llm": { + "usage_id": "agent", + "model": "test-provider/original", + "api_key": "test-key", + }, + "tools": [], + }, + "workspace": {"working_dir": str(workspace_path)}, + }, + timeout=10.0, + ) + create.raise_for_status() + conversation_id = create.json()["id"] + print("conversation id:", conversation_id) + + update = httpx.post( + f"{base_1}/api/conversations/{conversation_id}/llm", + json={ + "llm": { + "usage_id": "ignored-by-server", + "model": "test-provider/alternate", + "api_key": "test-key-2", + } + }, + timeout=10.0, + ) + update.raise_for_status() + + info = httpx.get( + f"{base_1}/api/conversations/{conversation_id}", + timeout=10.0, + ) + info.raise_for_status() + current_model = info.json()["agent"]["llm"]["model"] + print("server #1 model:", current_model) + if current_model != "test-provider/alternate": + raise RuntimeError("LLM switch did not apply on server #1") + finally: + _stop_agent_server(proc_1) + + proc_2, base_2 = _start_agent_server(conversations_path=conversations_path) + try: + print("agent-server #2:", base_2) + restored = httpx.get( + f"{base_2}/api/conversations/{conversation_id}", + timeout=10.0, + ) + restored.raise_for_status() + restored_model = 
restored.json()["agent"]["llm"]["model"] + print("server #2 restored model:", restored_model) + if restored_model != "test-provider/alternate": + raise RuntimeError("LLM switch did not persist across restart") + finally: + _stop_agent_server(proc_2) + + print("✓ LLM switch persisted across agent-server restart") + + base_state = ( + conversations_path / conversation_id.replace("-", "") / "base_state.json" + ) + if base_state.exists(): + payload = json.loads(base_state.read_text(encoding="utf-8")) + print("base_state.json agent.llm:", payload.get("agent", {}).get("llm")) + + +if __name__ == "__main__": + main() diff --git a/openhands-agent-server/openhands/agent_server/conversation_router.py b/openhands-agent-server/openhands/agent_server/conversation_router.py index 0da9457cc4..e50bff16aa 100644 --- a/openhands-agent-server/openhands/agent_server/conversation_router.py +++ b/openhands-agent-server/openhands/agent_server/conversation_router.py @@ -21,6 +21,8 @@ SetSecurityAnalyzerRequest, StartConversationRequest, Success, + SwitchLLMProfileRequest, + UpdateConversationLLMRequest, UpdateConversationRequest, UpdateSecretsRequest, ) @@ -254,6 +256,44 @@ async def set_conversation_security_analyzer( return Success() +@conversation_router.post( + "/{conversation_id}/llm/switch", + responses={404: {"description": "Item not found"}}, +) +async def switch_conversation_llm( + conversation_id: UUID, + request: SwitchLLMProfileRequest, + conversation_service: ConversationService = Depends(get_conversation_service), +) -> Success: + """Switch the conversation's active agent LLM profile for future requests.""" + event_service = await conversation_service.get_event_service(conversation_id) + if event_service is None: + raise HTTPException(status.HTTP_404_NOT_FOUND) + await event_service.switch_llm(request.profile_id) + return Success() + + +@conversation_router.post( + "/{conversation_id}/llm", + responses={404: {"description": "Item not found"}}, +) +async def update_conversation_llm( + conversation_id: UUID, + request: UpdateConversationLLMRequest, + conversation_service: ConversationService = Depends(get_conversation_service), +) -> Success: + """Update the conversation's active agent LLM for future requests.""" + event_service = await conversation_service.get_event_service(conversation_id) + if event_service is None: + raise HTTPException(status.HTTP_404_NOT_FOUND) + if request.profile_id is not None: + await event_service.switch_llm(request.profile_id) + else: + assert request.llm is not None + await event_service.set_llm(request.llm) + return Success() + + @conversation_router.patch( "/{conversation_id}", responses={404: {"description": "Item not found"}} ) diff --git a/openhands-agent-server/openhands/agent_server/event_service.py b/openhands-agent-server/openhands/agent_server/event_service.py index bbd5672993..85370f9d51 100644 --- a/openhands-agent-server/openhands/agent_server/event_service.py +++ b/openhands-agent-server/openhands/agent_server/event_service.py @@ -543,6 +543,32 @@ async def set_security_analyzer( None, self._conversation.set_security_analyzer, security_analyzer ) + async def _update_llm(self, update_fn, *args) -> None: + """Apply an LLM update and re-wire telemetry for future completions.""" + conversation = self._conversation + if conversation is None: + raise ValueError("inactive_service") + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, update_fn, *args) + # The agent may now hold a new LLM instance; re-wire telemetry callbacks so + # clients 
continue receiving logs/stats for future completions. + self._setup_llm_log_streaming(conversation.agent) + self._setup_stats_streaming(conversation.agent) + + async def switch_llm(self, profile_id: str) -> None: + """Switch the conversation's active agent LLM to the given profile.""" + conversation = self._conversation + if conversation is None: + raise ValueError("inactive_service") + await self._update_llm(conversation.switch_llm, profile_id) + + async def set_llm(self, llm: LLM) -> None: + """Replace the conversation's active agent LLM instance.""" + conversation = self._conversation + if conversation is None: + raise ValueError("inactive_service") + await self._update_llm(conversation.set_llm, llm) + async def close(self): await self._pub_sub.close() if self._conversation: diff --git a/openhands-agent-server/openhands/agent_server/models.py b/openhands-agent-server/openhands/agent_server/models.py index aeef0cab9f..b5e7c0c0b8 100644 --- a/openhands-agent-server/openhands/agent_server/models.py +++ b/openhands-agent-server/openhands/agent_server/models.py @@ -4,7 +4,7 @@ from typing import Any, Literal from uuid import uuid4 -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator from openhands.agent_server.utils import OpenHandsUUID, utc_now from openhands.sdk import LLM, AgentBase, Event, ImageContent, Message, TextContent @@ -107,6 +107,47 @@ class StartConversationRequest(BaseModel): ), ) + @model_validator(mode="before") + @classmethod + def _expand_agent_llm_profile_reference(cls, data: Any): + """Expand profile_id-only LLM payloads for server-side execution. + + FastAPI request parsing does not provide Pydantic validation context, so + we expand `{profile_id: ...}` here before the SDK's `LLM` model validates. + """ + if not isinstance(data, dict): + return data + + agent = data.get("agent") + if not isinstance(agent, dict): + return data + + llm = agent.get("llm") + if not isinstance(llm, dict): + return data + + profile_id = llm.get("profile_id") + if not profile_id or "model" in llm: + return data + + from openhands.sdk.llm import LLMRegistry + + registry = LLMRegistry() + loaded = registry.load_profile(str(profile_id)) + usage_id = llm.get("usage_id") or "agent" + + expanded = loaded.model_dump( + exclude_none=True, context={"expose_secrets": True} + ) + expanded["profile_id"] = str(profile_id) + expanded["usage_id"] = str(usage_id) + + new_agent = dict(agent) + new_agent["llm"] = expanded + new_data = dict(data) + new_data["agent"] = new_agent + return new_data + class StoredConversation(StartConversationRequest): """Stored details about a conversation""" @@ -238,6 +279,39 @@ class GenerateTitleRequest(BaseModel): ) +class SwitchLLMProfileRequest(BaseModel): + """Payload to switch the active agent LLM profile for a conversation.""" + + profile_id: str = Field( + ..., + min_length=1, + description="LLM profile ID to activate for the conversation", + ) + + +class UpdateConversationLLMRequest(BaseModel): + """Payload to update a conversation's active agent LLM. 
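+
+    Exactly one of the two fields must be provided; a model validator below
+    rejects payloads that set both or neither.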
+ + Supports either: + - `profile_id`: switch via server-side profile loading + - `llm`: set an inline LLM payload (for clients whose profile schema differs) + """ + + profile_id: str | None = Field( + default=None, + description="Optional LLM profile ID to activate for the conversation", + ) + llm: LLM | None = Field( + default=None, description="Optional inline LLM payload to activate" + ) + + @model_validator(mode="after") + def _validate_one_of(self): + if bool(self.profile_id) == bool(self.llm): + raise ValueError("Exactly one of profile_id or llm must be provided.") + return self + + class GenerateTitleResponse(BaseModel): """Response containing the generated conversation title.""" diff --git a/openhands-sdk/openhands/sdk/agent/base.py b/openhands-sdk/openhands/sdk/agent/base.py index 88b7634641..4bba6a4de0 100644 --- a/openhands-sdk/openhands/sdk/agent/base.py +++ b/openhands-sdk/openhands/sdk/agent/base.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from openhands.sdk.context.agent_context import AgentContext -from openhands.sdk.context.condenser import CondenserBase, LLMSummarizingCondenser +from openhands.sdk.context.condenser import CondenserBase from openhands.sdk.context.prompts.prompt import render_template from openhands.sdk.llm import LLM from openhands.sdk.llm.utils.model_prompt_spec import get_model_prompt_spec @@ -17,7 +17,6 @@ from openhands.sdk.mcp import create_mcp_tools from openhands.sdk.tool import BUILT_IN_TOOLS, Tool, ToolDefinition, resolve_tool from openhands.sdk.utils.models import DiscriminatedUnionMixin -from openhands.sdk.utils.pydantic_diff import pretty_pydantic_diff if TYPE_CHECKING: @@ -40,8 +39,8 @@ class AgentBase(DiscriminatedUnionMixin, ABC): """ model_config = ConfigDict( - frozen=True, arbitrary_types_allowed=True, + validate_assignment=True, ) llm: LLM = Field( @@ -300,72 +299,6 @@ def step( NOTE: state will be mutated in-place. """ - def resolve_diff_from_deserialized(self, persisted: "AgentBase") -> "AgentBase": - """ - Return a new AgentBase instance equivalent to `persisted` but with - explicitly whitelisted fields (e.g. api_key) taken from `self`. - """ - if persisted.__class__ is not self.__class__: - raise ValueError( - f"Cannot resolve from deserialized: persisted agent is of type " - f"{persisted.__class__.__name__}, but self is of type " - f"{self.__class__.__name__}." 
- ) - - # Get all LLMs from both self and persisted to reconcile them - new_llm = self.llm.resolve_diff_from_deserialized(persisted.llm) - updates: dict[str, Any] = {"llm": new_llm} - - # Reconcile the condenser's LLM if it exists - if self.condenser is not None and persisted.condenser is not None: - # Check if both condensers are LLMSummarizingCondenser - # (which has an llm field) - - if isinstance(self.condenser, LLMSummarizingCondenser) and isinstance( - persisted.condenser, LLMSummarizingCondenser - ): - new_condenser_llm = self.condenser.llm.resolve_diff_from_deserialized( - persisted.condenser.llm - ) - new_condenser = persisted.condenser.model_copy( - update={"llm": new_condenser_llm} - ) - updates["condenser"] = new_condenser - - # Reconcile agent_context - always use the current environment's agent_context - # This allows resuming conversations from different directories and handles - # cases where skills, working directory, or other context has changed - if self.agent_context is not None: - updates["agent_context"] = self.agent_context - - # Create maps by tool name for easy lookup - runtime_tools_map = {tool.name: tool for tool in self.tools} - persisted_tools_map = {tool.name: tool for tool in persisted.tools} - - # Check that tool names match - runtime_names = set(runtime_tools_map.keys()) - persisted_names = set(persisted_tools_map.keys()) - - if runtime_names != persisted_names: - missing_in_runtime = persisted_names - runtime_names - missing_in_persisted = runtime_names - persisted_names - error_msg = "Tools don't match between runtime and persisted agents." - if missing_in_runtime: - error_msg += f" Missing in runtime: {missing_in_runtime}." - if missing_in_persisted: - error_msg += f" Missing in persisted: {missing_in_persisted}." - raise ValueError(error_msg) - - reconciled = persisted.model_copy(update=updates) - if self.model_dump(exclude_none=True) != reconciled.model_dump( - exclude_none=True - ): - raise ValueError( - "The Agent provided is different from the one in persisted state.\n" - f"Diff: {pretty_pydantic_diff(self, reconciled)}" - ) - return reconciled - def _clone_with_llm(self, llm: LLM) -> "AgentBase": """Return a copy of this agent with ``llm`` swapped in.""" diff --git a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py index 2a4324b4d0..64f4c57ca3 100644 --- a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py @@ -2,6 +2,7 @@ import uuid from collections.abc import Mapping from pathlib import Path +from typing import cast from openhands.sdk.agent.base import AgentBase from openhands.sdk.context.prompts.prompt import render_template @@ -139,10 +140,23 @@ def __init__( stuck_detection=stuck_detection, llm_registry=self.llm_registry, ) + # After restore, the persisted base_state is the source of truth for the + # active agent/workspace configuration. + self.agent = self._state.agent + self.workspace = cast(LocalWorkspace, self._state.workspace) + ws_path = Path(self.workspace.working_dir) + if not ws_path.exists(): + ws_path.mkdir(parents=True, exist_ok=True) # Default callback: persist every event to state def _default_callback(e): self._state.events.append(e) + # Flush the base snapshot so nested updates (e.g. usage metrics) + # are persisted alongside the event stream. 
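+            # (persist_base_state is a no-op when persistence is disabled, so
+            # this flush is safe for purely in-memory conversations.)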
+ try: + self._state.persist_base_state() + except Exception: + logger.exception("Failed to persist base state after event append") self._hook_processor = None hook_callback = None @@ -262,6 +276,16 @@ def switch_llm(self, profile_id: str) -> None: profile_id, ) + def set_llm(self, llm: LLM) -> None: + """Replace the active agent LLM instance for future requests.""" + + with self._state: + self._state.set_agent_llm(llm, registry=self.llm_registry) + self.agent = self._state.agent + logger.info( + "Updated conversation %s LLM (usage_id=%s)", self._state.id, llm.usage_id + ) + @observe(name="conversation.send_message") def send_message(self, message: str | Message, sender: str | None = None) -> None: """Send a message to the agent. @@ -526,7 +550,8 @@ def close(self) -> None: try: tools_map = self.agent.tools_map except (AttributeError, RuntimeError): - # Agent not initialized or partially constructed + # Conversation may be partially constructed (e.g. validation failure) + # and the agent may not have been initialized. return for tool in tools_map.values(): try: diff --git a/openhands-sdk/openhands/sdk/conversation/state.py b/openhands-sdk/openhands/sdk/conversation/state.py index 5392bb738a..3114736974 100644 --- a/openhands-sdk/openhands/sdk/conversation/state.py +++ b/openhands-sdk/openhands/sdk/conversation/state.py @@ -24,7 +24,7 @@ from openhands.sdk.llm.llm_registry import LLMRegistry -from openhands.sdk.persistence import INLINE_CONTEXT_KEY, should_inline_conversations +from openhands.sdk.persistence import INLINE_CONTEXT_KEY from openhands.sdk.security.analyzer import SecurityAnalyzerBase from openhands.sdk.security.confirmation_policy import ( ConfirmationPolicyBase, @@ -36,6 +36,13 @@ logger = get_logger(__name__) +_RUNTIME_LLM_OVERLAY_FIELDS: tuple[str, ...] = ( + "api_key", + "aws_access_key_id", + "aws_secret_access_key", + "aws_region_name", +) + class ConversationExecutionStatus(str, Enum): """Enum representing the current execution state of the conversation.""" @@ -166,15 +173,28 @@ def _save_base_state(self, fs: FileStore) -> None: """ Persist base state snapshot (no events; events are file-backed). """ - inline_mode = should_inline_conversations() - # Pass the inline preference down so LLM serialization knows whether to - # inline credentials or persist a profile reference. payload = self.model_dump_json( exclude_none=True, - context={INLINE_CONTEXT_KEY: inline_mode}, + # Persist conversations using profile references when available, and + # include secrets so inline LLM configurations can restore without + # external reconciliation. + context={INLINE_CONTEXT_KEY: False, "expose_secrets": True}, ) fs.write(BASE_STATE, payload) + def persist_base_state(self) -> None: + """Persist the latest base_state snapshot if persistence is enabled. + + Note: This is intentionally explicit so nested mutations (e.g. stats / + metrics updates) can be flushed to disk even when `__setattr__` is not + triggered. + """ + + fs = getattr(self, "_fs", None) + if fs is None: + return + self._save_base_state(fs) + # ===== Factory: open-or-create (no load/save methods needed) ===== @classmethod def create( @@ -188,9 +208,10 @@ def create( llm_registry: "LLMRegistry | None" = None, ) -> "ConversationState": """ - If base_state.json exists: resume (attach EventLog, - reconcile agent, enforce id). - Else: create fresh (agent required), persist base, and return. + If base_state.json exists: resume (attach EventLog, validate id, and apply + restore-time configuration overrides). 
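+        On resume, the persisted agent's LLM remains active unless the runtime
+        agent pins `llm.profile_id`; secrets are always overlaid from the
+        runtime agent (see the restore-time override logic below).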
+ + Else: create fresh, persist base snapshot, and return. Args: llm_registry: Optional registry used to expand profile references when @@ -207,22 +228,20 @@ def create( except FileNotFoundError: base_text = None - inline_mode = should_inline_conversations() - # Keep validation and serialization in sync when loading previously - # persisted state. - context: dict[str, object] = {INLINE_CONTEXT_KEY: inline_mode} - if not inline_mode: - registry = llm_registry - if registry is None: - from openhands.sdk.llm.llm_registry import LLMRegistry + registry = llm_registry + if registry is None: + from openhands.sdk.llm.llm_registry import LLMRegistry - registry = LLMRegistry() - context["llm_registry"] = registry + registry = LLMRegistry() + # Provide the registry so LLM profile references can be expanded while + # materialising persisted state. + context: dict[str, object] = {"llm_registry": registry} # ---- Resume path ---- if base_text: base_payload = json.loads(base_text) state = cls.model_validate(base_payload, context=context) + persisted_agent = state.agent # Enforce conversation id match if state.id != id: @@ -231,16 +250,34 @@ def create( f"but persisted state has {state.id}" ) - # Reconcile agent config with deserialized one - resolved = agent.resolve_diff_from_deserialized(state.agent) - - # Attach runtime handles and commit reconciled agent (may autosave) + # Attach runtime handles. state._fs = file_store state._events = EventLog(file_store, dir_path=EVENTS_DIR) state._autosave_enabled = True - state.agent = resolved - state.stats = ConversationStats() + # Restore-time configuration: + # - The persisted state owns the *active* LLM by default (so a switched + # profile survives restarts). + # - The caller-provided agent (tools/context/etc) may override this. + # To explicitly override the persisted LLM on restore, pass an agent + # whose `llm.profile_id` is set. + if agent is not None: + effective_llm = ( + agent.llm + if agent.llm.profile_id is not None + else persisted_agent.llm + ) + # Always prefer secrets from the runtime agent, even when the + # persisted LLM selection is retained. + secret_updates: dict[str, object] = {} + for field in _RUNTIME_LLM_OVERLAY_FIELDS: + value = getattr(agent.llm, field) + if value is not None: + secret_updates[field] = value + if secret_updates: + effective_llm = effective_llm.model_copy(update=secret_updates) + state.agent = agent._clone_with_llm(effective_llm) + state.workspace = workspace logger.info( f"Resumed conversation {state.id} from persistent storage.\n" @@ -263,11 +300,8 @@ def create( max_iterations=max_iterations, stuck_detection=stuck_detection, ) - # Record existing analyzer configuration in state - state.security_analyzer = state.security_analyzer state._fs = file_store state._events = EventLog(file_store, dir_path=EVENTS_DIR) - state.stats = ConversationStats() state._save_base_state(file_store) # initial snapshot state._autosave_enabled = True @@ -350,16 +384,8 @@ def pop_blocked_message(self, message_id: str) -> str | None: def switch_agent_llm(self, profile_id: str, *, registry: "LLMRegistry") -> None: """Swap the agent's primary LLM to ``profile_id`` using ``registry``.""" - if should_inline_conversations(): - raise RuntimeError( - "LLM switching requires OPENHANDS_INLINE_CONVERSATIONS to be false." 
- ) - - if self.execution_status not in ( - ConversationExecutionStatus.IDLE, - ConversationExecutionStatus.FINISHED, - ): - raise RuntimeError("Agent must be idle before switching LLM profiles.") + if self.execution_status == ConversationExecutionStatus.RUNNING: + raise RuntimeError("Agent must be idle to switch LLM profiles.") usage_id = self.agent.llm.usage_id try: @@ -373,6 +399,21 @@ def switch_agent_llm(self, profile_id: str, *, registry: "LLMRegistry") -> None: if self.execution_status == ConversationExecutionStatus.FINISHED: self.execution_status = ConversationExecutionStatus.IDLE + def set_agent_llm(self, llm: "Any", *, registry: "LLMRegistry") -> None: + """Replace the agent's primary LLM instance for future requests. + + This supports remote clients that want to switch LLMs by sending an inline + LLM payload (as opposed to a server-side profile_id reference). + """ + if self.execution_status == ConversationExecutionStatus.RUNNING: + raise RuntimeError("Agent must be idle to switch LLMs.") + + usage_id = self.agent.llm.usage_id + new_llm = registry.set(usage_id, llm) + self.agent = self.agent._clone_with_llm(new_llm) + if self.execution_status == ConversationExecutionStatus.FINISHED: + self.execution_status = ConversationExecutionStatus.IDLE + @staticmethod def get_unmatched_actions(events: Sequence[Event]) -> list[ActionEvent]: """Find actions in the event history that don't have matching observations. diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index 6be73d5b72..8fdce085fd 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -32,8 +32,6 @@ if TYPE_CHECKING: # type hints only, avoid runtime import cycle from openhands.sdk.tool.tool import ToolDefinition -from openhands.sdk.utils.pydantic_diff import pretty_pydantic_diff - with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -84,7 +82,7 @@ from openhands.sdk.llm.utils.retry_mixin import RetryMixin from openhands.sdk.llm.utils.telemetry import Telemetry from openhands.sdk.logger import ENV_LOG_DIR, get_logger -from openhands.sdk.persistence import INLINE_CONTEXT_KEY, should_inline_conversations +from openhands.sdk.persistence import INLINE_CONTEXT_KEY logger = get_logger(__name__) @@ -332,18 +330,6 @@ class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin): exclude=True, ) _metrics: Metrics | None = PrivateAttr(default=None) - # ===== Plain class vars (NOT Fields) ===== - # When serializing, these fields (SecretStr) will be dump to "****" - # When deserializing, these fields will be ignored and we will override - # them from the LLM instance provided at runtime. - OVERRIDE_ON_SERIALIZE: tuple[str, ...] = ( - "api_key", - "aws_access_key_id", - "aws_secret_access_key", - # Dynamic runtime metadata for telemetry/routing that can differ across sessions - # and should not cause resume-time diffs. Always prefer the runtime value. - "litellm_extra_body", - ) # Runtime-only private attrs _model_info: Any = PrivateAttr(default=None) @@ -396,23 +382,9 @@ def _coerce_inputs(cls, data: Any, info: ValidationInfo): profile_id = d.get("profile_id") if profile_id and "model" not in d: - inline_pref = None - if info.context is not None and INLINE_CONTEXT_KEY in info.context: - inline_pref = info.context[INLINE_CONTEXT_KEY] - if inline_pref is None: - inline_pref = should_inline_conversations() - - if inline_pref: - raise ValueError( - "Encountered profile reference for LLM while " - "OPENHANDS_INLINE_CONVERSATIONS is enabled. 
" - "Inline the profile or set " - "OPENHANDS_INLINE_CONVERSATIONS=false." - ) - if info.context is None or "llm_registry" not in info.context: raise ValueError( - "LLM registry required in context to load profile references." + "LLM registry required in context to load LLM profile references." ) registry = info.context["llm_registry"] @@ -1163,39 +1135,3 @@ def _cast_value(raw: str, t: Any) -> Any: if v is not None: data[field_name] = v return cls(**data) - - def resolve_diff_from_deserialized(self, persisted: LLM) -> LLM: - """Resolve differences between a deserialized LLM and the current instance. - - This is due to fields like api_key being serialized to "****" in dumps, - and we want to ensure that when loading from a file, we still use the - runtime-provided api_key in the self instance. - - Return a new LLM instance equivalent to `persisted` but with - explicitly whitelisted fields (e.g. api_key) taken from `self`. - """ - if persisted.__class__ is not self.__class__: - raise ValueError( - f"Cannot resolve_diff_from_deserialized between {self.__class__} " - f"and {persisted.__class__}" - ) - - # Copy allowed fields from runtime llm into the persisted llm - llm_updates = {} - persisted_dump = persisted.model_dump(context={"expose_secrets": True}) - for field in self.OVERRIDE_ON_SERIALIZE: - if field in persisted_dump.keys(): - llm_updates[field] = getattr(self, field) - if llm_updates: - reconciled = persisted.model_copy(update=llm_updates) - else: - reconciled = persisted - - dump = self.model_dump(context={"expose_secrets": True}) - reconciled_dump = reconciled.model_dump(context={"expose_secrets": True}) - if dump != reconciled_dump: - raise ValueError( - "The LLM provided is different from the one in persisted state.\n" - f"Diff: {pretty_pydantic_diff(self, reconciled)}" - ) - return reconciled diff --git a/openhands-sdk/openhands/sdk/llm/llm_registry.py b/openhands-sdk/openhands/sdk/llm/llm_registry.py index a1cf6d4c85..cf83a6bbf1 100644 --- a/openhands-sdk/openhands/sdk/llm/llm_registry.py +++ b/openhands-sdk/openhands/sdk/llm/llm_registry.py @@ -1,4 +1,5 @@ import json +import os import re from collections.abc import Callable, Mapping from pathlib import Path @@ -20,6 +21,7 @@ "aws_secret_access_key", ) _DEFAULT_PROFILE_DIR = Path.home() / ".openhands" / "llm-profiles" +_PROFILE_DIR_ENV_VAR = "OPENHANDS_LLM_PROFILES_DIR" _PROFILE_ID_PATTERN = re.compile(r"^[A-Za-z0-9._-]+$") @@ -103,6 +105,18 @@ def add(self, llm: LLM) -> None: f"[LLM registry {self.registry_id}]: Added LLM for usage {usage_id}" ) + def set(self, usage_id: str, llm: LLM) -> LLM: + """Upsert the LLM for ``usage_id`` and return the stored instance.""" + + updated = llm.model_copy(update={"usage_id": usage_id}) + self._usage_to_llm[usage_id] = updated + self.notify(RegistryEvent(llm=updated)) + logger.info( + f"[LLM registry {self.registry_id}]: Set LLM for usage {usage_id} " + f"(model={updated.model})" + ) + return updated + def _ensure_safe_profile_id(self, profile_id: str) -> str: if not profile_id or profile_id in {".", ".."}: raise ValueError("Invalid profile ID.") @@ -136,7 +150,7 @@ def switch_profile(self, usage_id: str, profile_id: str) -> LLM: current_llm = self._usage_to_llm[usage_id] safe_id = self._ensure_safe_profile_id(profile_id) - if getattr(current_llm, "profile_id", None) == safe_id: + if current_llm.profile_id == safe_id: return current_llm llm = self.load_profile(safe_id) @@ -212,6 +226,9 @@ def validate_profile(self, data: Mapping[str, Any]) -> tuple[bool, list[str]]: def 
_resolve_profile_dir(self, profile_dir: str | Path | None) -> Path: if profile_dir is not None: return Path(profile_dir).expanduser() + env = os.getenv(_PROFILE_DIR_ENV_VAR) + if env: + return Path(env).expanduser() return _DEFAULT_PROFILE_DIR def _load_profile_with_synced_id(self, path: Path, profile_id: str) -> LLM: @@ -219,9 +236,8 @@ def _load_profile_with_synced_id(self, path: Path, profile_id: str) -> LLM: Most callers expect the loaded LLM to reflect the profile file name so the client apps can surface the active profile (e.g., in conversation history or CLI - prompts). We construct a *new* ``LLM`` via :meth:`model_copy` instead of - mutating the loaded instance to respect the SDK's immutability - conventions. + prompts). We keep the runtime instance aligned with the filename so UIs can + accurately display the active profile ID. We always align ``profile_id`` with the filename so callers get a precise view of which profile is active without mutating the on-disk payload. This diff --git a/openhands-sdk/openhands/sdk/persistence/__init__.py b/openhands-sdk/openhands/sdk/persistence/__init__.py index fb8eb9066c..8fb9ac6604 100644 --- a/openhands-sdk/openhands/sdk/persistence/__init__.py +++ b/openhands-sdk/openhands/sdk/persistence/__init__.py @@ -13,11 +13,9 @@ API. """ -from .settings import INLINE_CONTEXT_KEY, INLINE_ENV_VAR, should_inline_conversations +from .settings import INLINE_CONTEXT_KEY __all__ = [ "INLINE_CONTEXT_KEY", - "INLINE_ENV_VAR", - "should_inline_conversations", ] diff --git a/openhands-sdk/openhands/sdk/persistence/settings.py b/openhands-sdk/openhands/sdk/persistence/settings.py index bd8f51e490..c2029d838a 100644 --- a/openhands-sdk/openhands/sdk/persistence/settings.py +++ b/openhands-sdk/openhands/sdk/persistence/settings.py @@ -1,17 +1,11 @@ -"""Shared helpers for SDK persistence configuration.""" +"""Shared helpers for SDK persistence configuration. -from __future__ import annotations +This module intentionally avoids environment-driven behavior for conversation +serialization. Persistence should be deterministic and controlled by the caller +via explicit serialization context. 
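+
+Callers control inline vs. profile-reference serialization by passing
+`INLINE_CONTEXT_KEY` in the serialization context; for example,
+`ConversationState._save_base_state` passes `False` to prefer profile
+references.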
+""" -import os +from __future__ import annotations -INLINE_ENV_VAR = "OPENHANDS_INLINE_CONVERSATIONS" INLINE_CONTEXT_KEY = "inline_llm_persistence" -_FALSE_VALUES = {"0", "false", "no"} - - -def should_inline_conversations() -> bool: - """Return True when conversations should be persisted with inline LLM payloads.""" - - value = os.getenv(INLINE_ENV_VAR, "true").strip().lower() - return value not in _FALSE_VALUES diff --git a/tests/agent_server/test_agent_server_wsproto.py b/tests/agent_server/test_agent_server_wsproto.py index 3e0d8044f3..e2974b1a3f 100644 --- a/tests/agent_server/test_agent_server_wsproto.py +++ b/tests/agent_server/test_agent_server_wsproto.py @@ -20,8 +20,15 @@ def find_free_port(): return s.getsockname()[1] -def run_agent_server(port, api_key): +def run_agent_server(port, api_key, conversations_path=None, llm_profiles_dir=None): os.environ["OH_SESSION_API_KEYS"] = f'["{api_key}"]' + os.environ["OH_ENABLE_VSCODE"] = "0" + os.environ["OH_ENABLE_VNC"] = "0" + os.environ["OH_PRELOAD_TOOLS"] = "0" + if conversations_path is not None: + os.environ["OH_CONVERSATIONS_PATH"] = str(conversations_path) + if llm_profiles_dir is not None: + os.environ["OPENHANDS_LLM_PROFILES_DIR"] = str(llm_profiles_dir) sys.argv = ["agent-server", "--port", str(port)] from openhands.agent_server.__main__ import main @@ -104,3 +111,353 @@ async def test_agent_server_websocket_with_wsproto(agent_server): await ws.send( json.dumps({"role": "user", "content": "Hello from wsproto test"}) ) + + +def _wait_for_server(port: int) -> None: + for _ in range(30): + try: + response = requests.get(f"http://127.0.0.1:{port}/docs", timeout=1) + if response.status_code == 200: + return + except requests.exceptions.ConnectionError: + pass + time.sleep(1) + raise RuntimeError(f"Agent server failed to start on port {port}") + + +def test_agent_server_llm_switch_persists_across_restart(tmp_path): + api_key = "test-llm-switch-key" + conversations_path = tmp_path / "conversations" + llm_profiles_dir = tmp_path / "llm-profiles" + llm_profiles_dir.mkdir(parents=True, exist_ok=True) + + # Profile usage_id must not collide with the conversation's usage_id. 
+ (llm_profiles_dir / "alternate.json").write_text( + json.dumps({"model": "test-provider/alternate", "usage_id": "profile-slot"}), + encoding="utf-8", + ) + + port_1 = find_free_port() + process_1 = multiprocessing.Process( + target=run_agent_server, + args=(port_1, api_key, str(conversations_path), str(llm_profiles_dir)), + ) + process_1.start() + try: + _wait_for_server(port_1) + base_1 = f"http://127.0.0.1:{port_1}" + + response = requests.post( + f"{base_1}/api/conversations", + headers={"X-Session-API-Key": api_key}, + json={ + "agent": { + "llm": { + "usage_id": "test-llm", + "model": "test-provider/test-model", + "api_key": "test-key", + }, + "tools": [], + }, + "workspace": {"working_dir": str(tmp_path / "workspace")}, + }, + timeout=10, + ) + assert response.status_code in [200, 201] + conversation_id = response.json()["id"] + + switch = requests.post( + f"{base_1}/api/conversations/{conversation_id}/llm/switch", + headers={"X-Session-API-Key": api_key}, + json={"profile_id": "alternate"}, + timeout=10, + ) + assert switch.status_code == 200 + + info = requests.get( + f"{base_1}/api/conversations/{conversation_id}", + headers={"X-Session-API-Key": api_key}, + timeout=10, + ) + assert info.status_code == 200 + assert info.json()["agent"]["llm"]["profile_id"] == "alternate" + finally: + process_1.terminate() + process_1.join(timeout=5) + if process_1.is_alive(): + process_1.kill() + process_1.join() + + port_2 = find_free_port() + process_2 = multiprocessing.Process( + target=run_agent_server, + args=(port_2, api_key, str(conversations_path), str(llm_profiles_dir)), + ) + process_2.start() + try: + _wait_for_server(port_2) + base_2 = f"http://127.0.0.1:{port_2}" + + restored = requests.get( + f"{base_2}/api/conversations/{conversation_id}", + headers={"X-Session-API-Key": api_key}, + timeout=10, + ) + assert restored.status_code == 200 + assert restored.json()["agent"]["llm"]["profile_id"] == "alternate" + finally: + process_2.terminate() + process_2.join(timeout=5) + if process_2.is_alive(): + process_2.kill() + process_2.join() + + +def test_agent_server_set_llm_persists_across_restart(tmp_path): + api_key = "test-llm-set-key" + conversations_path = tmp_path / "conversations" + + port_1 = find_free_port() + process_1 = multiprocessing.Process( + target=run_agent_server, args=(port_1, api_key, str(conversations_path), None) + ) + process_1.start() + try: + _wait_for_server(port_1) + base_1 = f"http://127.0.0.1:{port_1}" + + response = requests.post( + f"{base_1}/api/conversations", + headers={"X-Session-API-Key": api_key}, + json={ + "agent": { + "llm": { + "usage_id": "test-llm", + "model": "test-provider/test-model", + "api_key": "test-key", + }, + "tools": [], + }, + "workspace": {"working_dir": str(tmp_path / "workspace")}, + }, + timeout=10, + ) + assert response.status_code in [200, 201] + conversation_id = response.json()["id"] + + update = requests.post( + f"{base_1}/api/conversations/{conversation_id}/llm", + headers={"X-Session-API-Key": api_key}, + json={ + "llm": { + "usage_id": "ignored-by-server", + "model": "test-provider/alternate", + "api_key": "test-key-2", + } + }, + timeout=10, + ) + assert update.status_code == 200 + + info = requests.get( + f"{base_1}/api/conversations/{conversation_id}", + headers={"X-Session-API-Key": api_key}, + timeout=10, + ) + assert info.status_code == 200 + assert info.json()["agent"]["llm"]["model"] == "test-provider/alternate" + finally: + process_1.terminate() + process_1.join(timeout=5) + if process_1.is_alive(): + 
process_1.kill() + process_1.join() + + port_2 = find_free_port() + process_2 = multiprocessing.Process( + target=run_agent_server, args=(port_2, api_key, str(conversations_path), None) + ) + process_2.start() + try: + _wait_for_server(port_2) + base_2 = f"http://127.0.0.1:{port_2}" + + restored = requests.get( + f"{base_2}/api/conversations/{conversation_id}", + headers={"X-Session-API-Key": api_key}, + timeout=10, + ) + assert restored.status_code == 200 + assert restored.json()["agent"]["llm"]["model"] == "test-provider/alternate" + finally: + process_2.terminate() + process_2.join(timeout=5) + if process_2.is_alive(): + process_2.kill() + process_2.join() + + +def test_agent_server_large_event_log_restore_and_runtime_llm_switch(tmp_path): + """End-to-end regression: large event history + restore + runtime LLM switching. + + This covers the remote-client path (VS Code / agent-sdk-ts) that: + - restores a conversation with a non-trivial number of events (pagination) + - switches the active LLM at runtime (while idle) + - switches again after server restart (restored conversation) + """ + + api_key = "test-llm-large-history-key" + conversations_path = tmp_path / "conversations" + llm_profiles_dir = tmp_path / "llm-profiles" + llm_profiles_dir.mkdir(parents=True, exist_ok=True) + + # Profile usage_id must not collide with the conversation's usage_id. + (llm_profiles_dir / "alternate.json").write_text( + json.dumps({"model": "test-provider/alternate", "usage_id": "profile-slot-a"}), + encoding="utf-8", + ) + (llm_profiles_dir / "second.json").write_text( + json.dumps({"model": "test-provider/second", "usage_id": "profile-slot-b"}), + encoding="utf-8", + ) + + def _post_event( + base: str, conversation_id: str, session: requests.Session, idx: int + ): + return session.post( + f"{base}/api/conversations/{conversation_id}/events", + json={ + "role": "user", + "content": [{"type": "text", "text": f"E2E history event {idx}"}], + "run": False, + }, + timeout=10, + ) + + event_count = 150 + + port_1 = find_free_port() + process_1 = multiprocessing.Process( + target=run_agent_server, + args=(port_1, api_key, str(conversations_path), str(llm_profiles_dir)), + ) + process_1.start() + try: + _wait_for_server(port_1) + base_1 = f"http://127.0.0.1:{port_1}" + + session = requests.Session() + session.headers.update({"X-Session-API-Key": api_key}) + + response = session.post( + f"{base_1}/api/conversations", + json={ + "agent": { + "llm": { + "usage_id": "test-llm", + "model": "test-provider/test-model", + "api_key": "test-key", + }, + "tools": [], + }, + "workspace": {"working_dir": str(tmp_path / "workspace")}, + }, + timeout=10, + ) + assert response.status_code in [200, 201] + conversation_id = response.json()["id"] + + for idx in range(event_count): + posted = _post_event(base_1, conversation_id, session, idx) + assert posted.status_code == 200 + + # Validate pagination works with a "realistic" event count. + count = session.get( + f"{base_1}/api/conversations/{conversation_id}/events/count", + timeout=10, + ) + assert count.status_code == 200 + assert int(count.text) >= event_count + + page_1 = session.get( + f"{base_1}/api/conversations/{conversation_id}/events/search", + params={"limit": 100}, + timeout=10, + ) + assert page_1.status_code == 200 + page_1_payload = page_1.json() + assert len(page_1_payload.get("items", [])) == 100 + assert page_1_payload.get("next_page_id") + + # Runtime switch (idle). 
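+        # (Allowed while idle: the state layer only rejects a switch when
+        # execution_status is RUNNING.)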
+ switch_1 = session.post( + f"{base_1}/api/conversations/{conversation_id}/llm/switch", + json={"profile_id": "alternate"}, + timeout=10, + ) + assert switch_1.status_code == 200 + + info_1 = session.get( + f"{base_1}/api/conversations/{conversation_id}", + timeout=10, + ) + assert info_1.status_code == 200 + assert info_1.json()["agent"]["llm"]["profile_id"] == "alternate" + + # Ensure we can keep appending events after the switch. + posted_after = _post_event(base_1, conversation_id, session, event_count + 1) + assert posted_after.status_code == 200 + finally: + process_1.terminate() + process_1.join(timeout=5) + if process_1.is_alive(): + process_1.kill() + process_1.join() + + port_2 = find_free_port() + process_2 = multiprocessing.Process( + target=run_agent_server, + args=(port_2, api_key, str(conversations_path), str(llm_profiles_dir)), + ) + process_2.start() + try: + _wait_for_server(port_2) + base_2 = f"http://127.0.0.1:{port_2}" + + session = requests.Session() + session.headers.update({"X-Session-API-Key": api_key}) + + restored = session.get( + f"{base_2}/api/conversations/{conversation_id}", + timeout=10, + ) + assert restored.status_code == 200 + assert restored.json()["agent"]["llm"]["profile_id"] == "alternate" + + restored_count = session.get( + f"{base_2}/api/conversations/{conversation_id}/events/count", + timeout=10, + ) + assert restored_count.status_code == 200 + assert int(restored_count.text) >= event_count + + # Runtime switch on restored conversation. + switch_2 = session.post( + f"{base_2}/api/conversations/{conversation_id}/llm/switch", + json={"profile_id": "second"}, + timeout=10, + ) + assert switch_2.status_code == 200 + + after_switch_2 = session.get( + f"{base_2}/api/conversations/{conversation_id}", + timeout=10, + ) + assert after_switch_2.status_code == 200 + assert after_switch_2.json()["agent"]["llm"]["profile_id"] == "second" + finally: + process_2.terminate() + process_2.join(timeout=5) + if process_2.is_alive(): + process_2.kill() + process_2.join() diff --git a/tests/agent_server/test_conversation_router.py b/tests/agent_server/test_conversation_router.py index 2c747214c7..d1e43ab509 100644 --- a/tests/agent_server/test_conversation_router.py +++ b/tests/agent_server/test_conversation_router.py @@ -134,6 +134,144 @@ def test_search_conversations_default_params( client.app.dependency_overrides.clear() +def test_switch_conversation_llm_success( + client, mock_conversation_service, mock_event_service, sample_conversation_id +): + mock_conversation_service.get_event_service.return_value = mock_event_service + client.app.dependency_overrides[get_conversation_service] = ( + lambda: mock_conversation_service + ) + + try: + response = client.post( + f"/api/conversations/{sample_conversation_id}/llm/switch", + json={"profile_id": "test-profile"}, + ) + assert response.status_code == 200 + assert response.json() == {"success": True} + mock_event_service.switch_llm.assert_awaited_once_with("test-profile") + finally: + client.app.dependency_overrides.clear() + + +def test_update_conversation_llm_profile_id_success( + client, mock_conversation_service, mock_event_service, sample_conversation_id +): + mock_conversation_service.get_event_service.return_value = mock_event_service + client.app.dependency_overrides[get_conversation_service] = ( + lambda: mock_conversation_service + ) + + try: + response = client.post( + f"/api/conversations/{sample_conversation_id}/llm", + json={"profile_id": "test-profile"}, + ) + assert response.status_code == 200 + assert 
response.json() == {"success": True} + mock_event_service.switch_llm.assert_awaited_once_with("test-profile") + mock_event_service.set_llm.assert_not_awaited() + finally: + client.app.dependency_overrides.clear() + + +def test_update_conversation_llm_inline_payload_success( + client, mock_conversation_service, mock_event_service, sample_conversation_id +): + mock_conversation_service.get_event_service.return_value = mock_event_service + client.app.dependency_overrides[get_conversation_service] = ( + lambda: mock_conversation_service + ) + + try: + response = client.post( + f"/api/conversations/{sample_conversation_id}/llm", + json={ + "llm": { + "usage_id": "agent", + "model": "test-provider/test-model", + "api_key": "test-key", + } + }, + ) + assert response.status_code == 200 + assert response.json() == {"success": True} + mock_event_service.switch_llm.assert_not_awaited() + assert mock_event_service.set_llm.await_count == 1 + + (called_llm,) = mock_event_service.set_llm.await_args.args + assert isinstance(called_llm, LLM) + assert called_llm.usage_id == "agent" + assert called_llm.model == "test-provider/test-model" + finally: + client.app.dependency_overrides.clear() + + +@pytest.mark.parametrize( + "payload", + [ + {}, + {"profile_id": "test-profile", "llm": {"usage_id": "agent", "model": "x"}}, + ], +) +def test_update_conversation_llm_requires_exactly_one_field( + client, + mock_conversation_service, + mock_event_service, + sample_conversation_id, + payload, +): + mock_conversation_service.get_event_service.return_value = mock_event_service + client.app.dependency_overrides[get_conversation_service] = ( + lambda: mock_conversation_service + ) + + try: + response = client.post( + f"/api/conversations/{sample_conversation_id}/llm", + json=payload, + ) + assert response.status_code == 422 + finally: + client.app.dependency_overrides.clear() + + +def test_update_conversation_llm_not_found( + client, mock_conversation_service, sample_conversation_id +): + mock_conversation_service.get_event_service.return_value = None + client.app.dependency_overrides[get_conversation_service] = ( + lambda: mock_conversation_service + ) + + try: + response = client.post( + f"/api/conversations/{sample_conversation_id}/llm", + json={"profile_id": "test-profile"}, + ) + assert response.status_code == 404 + finally: + client.app.dependency_overrides.clear() + + +def test_switch_conversation_llm_not_found( + client, mock_conversation_service, sample_conversation_id +): + mock_conversation_service.get_event_service.return_value = None + client.app.dependency_overrides[get_conversation_service] = ( + lambda: mock_conversation_service + ) + + try: + response = client.post( + f"/api/conversations/{sample_conversation_id}/llm/switch", + json={"profile_id": "test-profile"}, + ) + assert response.status_code == 404 + finally: + client.app.dependency_overrides.clear() + + def test_search_conversations_with_all_params( client, mock_conversation_service, sample_conversation_info ): @@ -485,6 +623,7 @@ def test_start_conversation_new( # Create request data with proper serialization request_data = { "agent": { + "kind": "Agent", "llm": { "model": "gpt-4o", "api_key": "test-key", @@ -531,6 +670,7 @@ def test_start_conversation_existing( # Create request data with proper serialization request_data = { "agent": { + "kind": "Agent", "llm": { "model": "gpt-4o", "api_key": "test-key", @@ -590,6 +730,7 @@ def test_start_conversation_minimal_request( # Create minimal valid request minimal_request = { "agent": { + "kind": 
"Agent", "llm": { "model": "gpt-4o", "api_key": "test-key", diff --git a/tests/cross/test_agent_reconciliation.py b/tests/cross/test_agent_reconciliation.py index 7864ca729a..05d1cc9030 100644 --- a/tests/cross/test_agent_reconciliation.py +++ b/tests/cross/test_agent_reconciliation.py @@ -7,7 +7,6 @@ from pydantic import SecretStr from openhands.sdk import Agent -from openhands.sdk.agent import AgentBase from openhands.sdk.context import AgentContext, Skill from openhands.sdk.context.condenser.llm_summarizing_condenser import ( LLMSummarizingCondenser, @@ -107,10 +106,8 @@ def test_conversation_restarted_with_changed_working_directory(tmp_path_factory) # Tests from test_local_conversation_tools_integration.py -def test_conversation_with_different_agent_tools_fails(): - """Test that using an agent with different tools fails (tools must match).""" - import pytest - +def test_conversation_with_different_agent_tools_succeeds(): + """Conversation restart should allow swapping the agent's tool set.""" with tempfile.TemporaryDirectory() as temp_dir: # Create and save conversation with original agent original_tools = [ @@ -146,17 +143,18 @@ def test_conversation_with_different_agent_tools_fails(): ) different_agent = Agent(llm=llm2, tools=different_tools) - # This should fail - tools must match during reconciliation - with pytest.raises( - ValueError, match="Tools don't match between runtime and persisted agents" - ): - LocalConversation( - agent=different_agent, - workspace=temp_dir, - persistence_dir=temp_dir, - conversation_id=conversation_id, # Use same ID to avoid ID mismatch - visualizer=None, - ) + # Restart should succeed, adopting the runtime agent's tools. + restarted = LocalConversation( + agent=different_agent, + workspace=temp_dir, + persistence_dir=temp_dir, + conversation_id=conversation_id, # Use same ID to avoid ID mismatch + visualizer=None, + ) + assert len(restarted.agent.tools) == 1 + assert restarted.agent.tools[0].name == "TerminalTool" + assert "terminal" in restarted.agent.tools_map + assert "file_editor" not in restarted.agent.tools_map def test_conversation_with_same_agent_succeeds(): @@ -212,59 +210,6 @@ def test_conversation_with_same_agent_succeeds(): assert len(new_conversation.state.events) > 0 -def test_agent_resolve_diff_from_deserialized(): - """Test agent's resolve_diff_from_deserialized method. - - Includes tolerance for litellm_extra_body differences injected at CLI load time. 
- """ - with tempfile.TemporaryDirectory(): - # Create original agent - tools = [Tool(name="TerminalTool")] - llm = LLM( - model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test-llm" - ) - original_agent = Agent(llm=llm, tools=tools) - - # Serialize and deserialize to simulate persistence - serialized = original_agent.model_dump_json() - deserialized_agent = AgentBase.model_validate_json(serialized) - - # Create runtime agent with same configuration - llm2 = LLM( - model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test-llm" - ) - runtime_agent = Agent(llm=llm2, tools=tools) - - # Should resolve successfully - resolved = runtime_agent.resolve_diff_from_deserialized(deserialized_agent) - # Test model_dump equality - assert resolved.model_dump(mode="json") == runtime_agent.model_dump(mode="json") - assert resolved.llm.model == runtime_agent.llm.model - assert resolved.__class__ == runtime_agent.__class__ - - # Now simulate CLI injecting dynamic litellm_extra_body metadata at load time - injected = deserialized_agent.model_copy( - update={ - "llm": deserialized_agent.llm.model_copy( - update={ - "litellm_extra_body": { - "metadata": { - "session_id": "sess-123", - "tags": ["app:openhands", "model:gpt-4o-mini"], - "trace_version": "1.2.3", - } - } - } - ) - } - ) - - # Reconcile again: differences in litellm_extra_body should be allowed and - # the runtime value should be preferred without raising an error. - resolved2 = runtime_agent.resolve_diff_from_deserialized(injected) - assert resolved2.llm.litellm_extra_body == runtime_agent.llm.litellm_extra_body - - @patch("openhands.sdk.llm.llm.litellm_completion") def test_conversation_persistence_lifecycle(mock_completion): """Test full conversation persistence lifecycle similar to examples/10_persistence.py.""" # noqa: E501 diff --git a/tests/sdk/agent/test_agent_immutability.py b/tests/sdk/agent/test_agent_immutability.py index 961f4077a3..917531d943 100644 --- a/tests/sdk/agent/test_agent_immutability.py +++ b/tests/sdk/agent/test_agent_immutability.py @@ -1,4 +1,4 @@ -"""Tests for Agent immutability and statelessness.""" +"""Tests for Agent component swapping and statelessness.""" import pytest from pydantic import SecretStr, ValidationError @@ -7,8 +7,8 @@ from openhands.sdk.llm import LLM -class TestAgentImmutability: - """Test Agent immutability and statelessness.""" +class TestAgentComponentSwaps: + """Test Agent component swapping and statelessness.""" def setup_method(self): """Set up test environment.""" @@ -16,22 +16,30 @@ def setup_method(self): model="gpt-4", api_key=SecretStr("test-key"), usage_id="test-llm" ) - def test_agent_is_frozen(self): - """Test that Agent instances are frozen (immutable).""" + def test_agent_allows_component_swaps(self): + """Agent should support swapping components via cloning.""" agent = Agent(llm=self.llm, tools=[]) - # Test that we cannot modify core fields after creation - with pytest.raises(ValidationError, match="Instance is frozen"): - agent.llm = "new_value" # type: ignore[assignment] + new_llm = LLM( + model="gpt-4", + api_key=SecretStr("new-key"), + usage_id="test-llm", + ) + swapped = agent._clone_with_llm(new_llm) + assert swapped.llm == new_llm + assert agent.llm == self.llm - with pytest.raises(ValidationError, match="Instance is frozen"): - agent.agent_context = None + with pytest.raises(ValidationError): + Agent.model_validate({"llm": "new_value", "tools": []}) - # Verify the agent remains functional after failed modification attempts - assert agent.llm == self.llm + 
+        # Verify the agent remains functional after modification attempts
         assert isinstance(agent.system_message, str)
         assert len(agent.system_message) > 0
 
+        without_context = agent.model_copy(update={"agent_context": None})
+        assert isinstance(without_context.system_message, str)
+        assert len(without_context.system_message) > 0
+
     def test_system_message_is_computed_property(self):
         """Test that system_message is computed on-demand, not stored."""
         agent = Agent(llm=self.llm, tools=[])
diff --git a/tests/sdk/conversation/local/test_state_serialization.py b/tests/sdk/conversation/local/test_state_serialization.py
index 3e515c6824..d9e8889389 100644
--- a/tests/sdk/conversation/local/test_state_serialization.py
+++ b/tests/sdk/conversation/local/test_state_serialization.py
@@ -9,16 +9,11 @@
 from pydantic import SecretStr
 
 from openhands.sdk import Agent, Conversation
-from openhands.sdk.agent.base import AgentBase
 from openhands.sdk.conversation.impl.local_conversation import LocalConversation
 from openhands.sdk.conversation.state import (
     ConversationExecutionStatus,
     ConversationState,
 )
-from openhands.sdk.conversation.types import (
-    ConversationCallbackType,
-    ConversationTokenCallbackType,
-)
 from openhands.sdk.event.llm_convertible import MessageEvent, SystemPromptEvent
 from openhands.sdk.llm import LLM, Message, TextContent
 from openhands.sdk.llm.llm_registry import LLMRegistry, RegistryEvent
@@ -140,8 +135,8 @@
         assert loaded_dump == original_dump
     if original_stats is not None:
         assert loaded_stats is not None
-        loaded_metrics = loaded_stats.get("service_to_metrics", {})
-        for key, metric in original_stats.get("service_to_metrics", {}).items():
+        loaded_metrics = loaded_stats.get("usage_to_metrics", {})
+        for key, metric in original_stats.get("usage_to_metrics", {}).items():
             assert key in loaded_metrics
             assert loaded_metrics[key] == metric
     # Also verify key fields are preserved
@@ -150,11 +145,10 @@
 
 
 def test_conversation_state_profile_reference_mode(tmp_path, monkeypatch):
-    """When inline persistence is disabled we store profile references."""
+    """Conversation persistence stores LLM profile references."""
 
     home_dir = tmp_path / "home"
     monkeypatch.setenv("HOME", str(home_dir))
-    monkeypatch.setenv("OPENHANDS_INLINE_CONVERSATIONS", "false")
 
     registry = LLMRegistry()
     llm = LLM(model="litellm_proxy/openai/gpt-5-mini", usage_id="agent")
@@ -187,45 +181,6 @@
     assert loaded_state.agent.llm.model == llm.model
 
 
-def test_conversation_state_inline_mode_errors_on_profile_reference(
-    tmp_path, monkeypatch
-):
-    """Inline mode raises when encountering a persisted profile reference."""
-
-    home_dir = tmp_path / "home"
-    monkeypatch.setenv("HOME", str(home_dir))
-    monkeypatch.setenv("OPENHANDS_INLINE_CONVERSATIONS", "false")
-
-    registry = LLMRegistry()
-    llm = LLM(model="litellm_proxy/openai/gpt-5-mini", usage_id="agent")
-    registry.save_profile("profile-inline", llm)
-    agent = Agent(llm=registry.load_profile("profile-inline"), tools=[])
-
-    conv_id = uuid.UUID("12345678-1234-5678-9abc-1234567890aa")
-    persistence_root = tmp_path / "conv"
-    persistence_dir = LocalConversation.get_persistence_dir(persistence_root, conv_id)
-
-    ConversationState.create(
-        workspace=LocalWorkspace(working_dir="/tmp"),
-        persistence_dir=persistence_dir,
-        agent=agent,
-        id=conv_id,
-    )
-
-    # Switch env back to inline mode and expect a failure on reload
- monkeypatch.setenv("OPENHANDS_INLINE_CONVERSATIONS", "true") - - with pytest.raises(ValueError) as exc: - Conversation( - agent=agent, - persistence_dir=persistence_root, - workspace=LocalWorkspace(working_dir="/tmp"), - conversation_id=conv_id, - ) - - assert "OPENHANDS_INLINE_CONVERSATIONS" in str(exc.value) - - def test_conversation_state_incremental_save(): """Test that ConversationState saves events incrementally.""" with tempfile.TemporaryDirectory() as temp_dir: @@ -286,8 +241,8 @@ def test_conversation_state_incremental_save(): assert loaded_dump == original_dump if original_stats is not None: assert loaded_stats is not None - loaded_metrics = loaded_stats.get("service_to_metrics", {}) - for key, metric in original_stats.get("service_to_metrics", {}).items(): + loaded_metrics = loaded_stats.get("usage_to_metrics", {}) + for key, metric in original_stats.get("usage_to_metrics", {}).items(): assert key in loaded_metrics assert loaded_metrics[key] == metric @@ -523,37 +478,6 @@ def test_conversation_state_thread_safety(): assert not state.owned() -def test_agent_resolve_diff_different_class_raises_error(): - """Test that resolve_diff_from_deserialized raises error for different agent classes.""" # noqa: E501 - - class DifferentAgent(AgentBase): - def __init__(self): - llm = LLM( - model="gpt-4o-mini", - api_key=SecretStr("test-key"), - usage_id="test-llm", - ) - super().__init__(llm=llm, tools=[]) - - def init_state(self, state, on_event): - pass - - def step( - self, - conversation, - on_event: ConversationCallbackType, - on_token: ConversationTokenCallbackType | None = None, - ): - pass - - llm = LLM(model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test-llm") - original_agent = Agent(llm=llm, tools=[]) - different_agent = DifferentAgent() - - with pytest.raises(ValueError, match="Cannot resolve from deserialized"): - original_agent.resolve_diff_from_deserialized(different_agent) - - def test_conversation_state_flags_persistence(): """Test that conversation state flags are properly persisted.""" with tempfile.TemporaryDirectory() as temp_dir: @@ -594,9 +518,6 @@ def test_conversation_state_flags_persistence(): assert loaded_state.execution_status == ConversationExecutionStatus.FINISHED assert loaded_state.confirmation_policy == AlwaysConfirm() assert loaded_state.activated_knowledge_skills == ["agent1", "agent2"] - # Test model_dump equality - assert loaded_state.model_dump(mode="json") != state.model_dump(mode="json") - loaded_state.stats.register_llm(RegistryEvent(llm=llm)) assert loaded_state.model_dump(mode="json") == state.model_dump(mode="json") @@ -658,7 +579,6 @@ def test_local_conversation_switch_llm_persists_profile(tmp_path, monkeypatch): home_dir = tmp_path / "home" home_dir.mkdir() monkeypatch.setenv("HOME", str(home_dir)) - monkeypatch.setenv("OPENHANDS_INLINE_CONVERSATIONS", "false") registry = LLMRegistry() base_llm = LLM(model="gpt-4o-mini", usage_id="test-llm") @@ -703,35 +623,10 @@ def test_local_conversation_switch_llm_persists_profile(tmp_path, monkeypatch): assert reloaded.state.agent.llm.profile_id == "alt" -def test_local_conversation_switch_llm_inline_mode_rejected(tmp_path, monkeypatch): - home_dir = tmp_path / "home" - home_dir.mkdir() - monkeypatch.setenv("HOME", str(home_dir)) - monkeypatch.setenv("OPENHANDS_INLINE_CONVERSATIONS", "true") - - registry = LLMRegistry() - base_llm = LLM(model="gpt-4o-mini", usage_id="test-llm") - registry.save_profile("base", base_llm) - registry.save_profile("alt", LLM(model="gpt-4o", usage_id="alternate")) - 
-    agent = Agent(llm=registry.load_profile("base"), tools=[])
-    conversation = Conversation(
-        agent=agent,
-        workspace=str(tmp_path / "workspace"),
-        persistence_dir=str(tmp_path / "persist"),
-        visualizer=None,
-    )
-    assert isinstance(conversation, LocalConversation)
-
-    with pytest.raises(RuntimeError, match="OPENHANDS_INLINE_CONVERSATIONS"):
-        conversation.switch_llm("alt")
-
-
 def test_local_conversation_switch_llm_requires_idle(tmp_path, monkeypatch):
     home_dir = tmp_path / "home"
     home_dir.mkdir()
     monkeypatch.setenv("HOME", str(home_dir))
-    monkeypatch.setenv("OPENHANDS_INLINE_CONVERSATIONS", "false")
 
     registry = LLMRegistry()
     base_llm = LLM(model="gpt-4o-mini", usage_id="test-llm")
@@ -758,7 +653,6 @@ def test_local_conversation_switch_llm_missing_profile_rejected(tmp_path, monkey
     home_dir = tmp_path / "home"
     home_dir.mkdir()
     monkeypatch.setenv("HOME", str(home_dir))
-    monkeypatch.setenv("OPENHANDS_INLINE_CONVERSATIONS", "false")
 
     registry = LLMRegistry()
     base_llm = LLM(model="gpt-4o-mini", usage_id="test-llm")
diff --git a/tests/sdk/llm/test_llm_reconciliation.py b/tests/sdk/llm/test_llm_reconciliation.py
deleted file mode 100644
index 2cd86d75e2..0000000000
--- a/tests/sdk/llm/test_llm_reconciliation.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from pydantic import SecretStr
-
-from openhands.sdk import Agent
-from openhands.sdk.agent import AgentBase
-from openhands.sdk.llm import LLM
-
-
-def test_resolve_diff_ignores_litellm_extra_body_diffs():
-    tools = []
-    llm = LLM(model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test-llm")
-    original_agent = Agent(llm=llm, tools=tools)
-
-    serialized = original_agent.model_dump_json()
-    deserialized_agent = AgentBase.model_validate_json(serialized)
-
-    runtime_agent = Agent(
-        llm=LLM(
-            model="gpt-4o-mini",
-            api_key=SecretStr("test-key"),
-            usage_id="test-llm",
-        ),
-        tools=tools,
-    )
-
-    injected = deserialized_agent.model_copy(
-        update={
-            "llm": deserialized_agent.llm.model_copy(
-                update={
-                    "litellm_extra_body": {
-                        "metadata": {
-                            "session_id": "sess-xyz",
-                            "tags": ["app:openhands", "model:gpt-4o-mini"],
-                            "trace_version": "9.9.9",
-                        }
-                    }
-                }
-            )
-        }
-    )
-
-    resolved = runtime_agent.resolve_diff_from_deserialized(injected)
-    assert resolved.llm.litellm_extra_body == runtime_agent.llm.litellm_extra_body
diff --git a/tests/sdk/llm/test_llm_registry_profiles.py b/tests/sdk/llm/test_llm_registry_profiles.py
index fa404455d8..45553cfc7b 100644
--- a/tests/sdk/llm/test_llm_registry_profiles.py
+++ b/tests/sdk/llm/test_llm_registry_profiles.py
@@ -98,7 +98,6 @@ def test_llm_serializer_respects_inline_context():
 
 
 def test_llm_validator_loads_profile_reference(tmp_path, monkeypatch):
-    monkeypatch.setenv("OPENHANDS_INLINE_CONVERSATIONS", "false")
     registry = LLMRegistry(profile_dir=tmp_path)
     source_llm = LLM(model="gpt-4o-mini", usage_id="service")
     registry.save_profile("profile-tests", source_llm)
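
Reviewer note: the tests above exercise the switch-and-persist flow end to end. For orientation, here is a minimal sketch of that flow; the profile ids ("base", "alt"), model names, and local paths are illustrative rather than fixtures from this patch, and both profiles share one `usage_id` so they are swappable at runtime, matching the pattern in the tests.

```python
# Sketch only: mirrors test_local_conversation_switch_llm_persists_profile.
# Profile ids, model names, and paths below are illustrative.
from openhands.sdk import Agent, Conversation
from openhands.sdk.llm import LLM
from openhands.sdk.llm.llm_registry import LLMRegistry

registry = LLMRegistry()
# Both profiles share usage_id "test-llm" so they can be swapped at runtime.
registry.save_profile("base", LLM(model="gpt-4o-mini", usage_id="test-llm"))
registry.save_profile("alt", LLM(model="gpt-4o", usage_id="test-llm"))

conversation = Conversation(
    agent=Agent(llm=registry.load_profile("base"), tools=[]),
    workspace="./workspace",
    persistence_dir="./persist",
)

# Swap the active LLM while the conversation is idle; the base state snapshot
# now stores {"profile_id": "alt"}, so a reload restores the switched LLM.
conversation.switch_llm("alt")
assert conversation.state.agent.llm.profile_id == "alt"
```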