diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index f5961472..90a05563 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -24,11 +24,13 @@ PermissionHandler, PermissionRequest, PermissionRequestResult, + PingResponse, ProviderConfig, ResumeSessionConfig, SessionConfig, SessionEvent, SessionMetadata, + StopError, Tool, ToolHandler, ToolInvocation, @@ -56,11 +58,13 @@ "PermissionHandler", "PermissionRequest", "PermissionRequestResult", + "PingResponse", "ProviderConfig", "ResumeSessionConfig", "SessionConfig", "SessionEvent", "SessionMetadata", + "StopError", "Tool", "ToolHandler", "ToolInvocation", diff --git a/python/copilot/client.py b/python/copilot/client.py index 6870bda4..522a2f2b 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -19,7 +19,7 @@ import subprocess import threading from dataclasses import asdict, is_dataclass -from typing import Any, Optional, cast +from typing import Any, Optional from .generated.session_events import session_event_from_dict from .jsonrpc import JsonRpcClient @@ -32,10 +32,12 @@ GetAuthStatusResponse, GetStatusResponse, ModelInfo, + PingResponse, ProviderConfig, ResumeSessionConfig, SessionConfig, SessionMetadata, + StopError, ToolHandler, ToolInvocation, ToolResult, @@ -220,7 +222,7 @@ async def start(self) -> None: self._state = "error" raise - async def stop(self) -> list[dict[str, str]]: + async def stop(self) -> list["StopError"]: """ Stop the CLI server and close all active sessions. @@ -230,16 +232,16 @@ async def stop(self) -> list[dict[str, str]]: 3. Terminates the CLI server process (if spawned by this client) Returns: - A list of errors that occurred during cleanup, each as a dict with - a 'message' key. An empty list indicates all cleanup succeeded. + A list of StopError objects containing error messages that occurred + during cleanup. An empty list indicates all cleanup succeeded. Example: >>> errors = await client.stop() >>> if errors: ... for error in errors: - ... print(f"Cleanup error: {error['message']}") + ... print(f"Cleanup error: {error.message}") """ - errors: list[dict[str, str]] = [] + errors: list[StopError] = [] # Atomically take ownership of all sessions and clear the dict # so no other thread can access them @@ -251,7 +253,9 @@ async def stop(self) -> list[dict[str, str]]: try: await session.destroy() except Exception as e: - errors.append({"message": f"Failed to destroy session {session.session_id}: {e}"}) + errors.append( + StopError(message=f"Failed to destroy session {session.session_id}: {e}") + ) # Close client if self._client: @@ -570,7 +574,7 @@ def get_state(self) -> ConnectionState: """ return self._state - async def ping(self, message: Optional[str] = None) -> dict: + async def ping(self, message: Optional[str] = None) -> "PingResponse": """ Send a ping request to the server to verify connectivity. @@ -578,59 +582,61 @@ async def ping(self, message: Optional[str] = None) -> dict: message: Optional message to include in the ping. Returns: - A dict containing the ping response with 'message', 'timestamp', - and 'protocolVersion' keys. + A PingResponse object containing the ping response. Raises: RuntimeError: If the client is not connected. Example: >>> response = await client.ping("health check") - >>> print(f"Server responded at {response['timestamp']}") + >>> print(f"Server responded at {response.timestamp}") """ if not self._client: raise RuntimeError("Client not connected") - return await self._client.request("ping", {"message": message}) + result = await self._client.request("ping", {"message": message}) + return PingResponse.from_dict(result) async def get_status(self) -> "GetStatusResponse": """ Get CLI status including version and protocol information. Returns: - A GetStatusResponse containing version and protocolVersion. + A GetStatusResponse object containing version and protocolVersion. Raises: RuntimeError: If the client is not connected. Example: >>> status = await client.get_status() - >>> print(f"CLI version: {status['version']}") + >>> print(f"CLI version: {status.version}") """ if not self._client: raise RuntimeError("Client not connected") - return await self._client.request("status.get", {}) + result = await self._client.request("status.get", {}) + return GetStatusResponse.from_dict(result) async def get_auth_status(self) -> "GetAuthStatusResponse": """ Get current authentication status. Returns: - A GetAuthStatusResponse containing authentication state. + A GetAuthStatusResponse object containing authentication state. Raises: RuntimeError: If the client is not connected. Example: >>> auth = await client.get_auth_status() - >>> if auth['isAuthenticated']: - ... print(f"Logged in as {auth.get('login')}") + >>> if auth.isAuthenticated: + ... print(f"Logged in as {auth.login}") """ if not self._client: raise RuntimeError("Client not connected") - return await self._client.request("auth.getStatus", {}) + result = await self._client.request("auth.getStatus", {}) + return GetAuthStatusResponse.from_dict(result) async def list_models(self) -> list["ModelInfo"]: """ @@ -646,13 +652,14 @@ async def list_models(self) -> list["ModelInfo"]: Example: >>> models = await client.list_models() >>> for model in models: - ... print(f"{model['id']}: {model['name']}") + ... print(f"{model.id}: {model.name}") """ if not self._client: raise RuntimeError("Client not connected") response = await self._client.request("models.list", {}) - return response.get("models", []) + models_data = response.get("models", []) + return [ModelInfo.from_dict(model) for model in models_data] async def list_sessions(self) -> list["SessionMetadata"]: """ @@ -661,9 +668,7 @@ async def list_sessions(self) -> list["SessionMetadata"]: Returns metadata about each session including ID, timestamps, and summary. Returns: - A list of session metadata dictionaries with keys: sessionId (str), - startTime (str), modifiedTime (str), summary (str, optional), - and isRemote (bool). + A list of SessionMetadata objects. Raises: RuntimeError: If the client is not connected. @@ -671,13 +676,14 @@ async def list_sessions(self) -> list["SessionMetadata"]: Example: >>> sessions = await client.list_sessions() >>> for session in sessions: - ... print(f"Session: {session['sessionId']}") + ... print(f"Session: {session.sessionId}") """ if not self._client: raise RuntimeError("Client not connected") response = await self._client.request("session.list", {}) - return response.get("sessions", []) + sessions_data = response.get("sessions", []) + return [SessionMetadata.from_dict(session) for session in sessions_data] async def delete_session(self, session_id: str) -> None: """ @@ -714,7 +720,7 @@ async def _verify_protocol_version(self) -> None: """Verify that the server's protocol version matches the SDK's expected version.""" expected_version = get_sdk_protocol_version() ping_result = await self.ping() - server_version = ping_result.get("protocolVersion") + server_version = ping_result.protocolVersion if server_version is None: raise RuntimeError( @@ -845,11 +851,11 @@ async def read_port(): if not process or not process.stdout: raise RuntimeError("Process not started or stdout not available") while True: - line = cast(bytes, await loop.run_in_executor(None, process.stdout.readline)) + line = await loop.run_in_executor(None, process.stdout.readline) if not line: raise RuntimeError("CLI process exited before announcing port") - line_str = line.decode() + line_str = line.decode() if isinstance(line, bytes) else line match = re.search(r"listening on port (\d+)", line_str, re.IGNORECASE) if match: self._actual_port = int(match.group(1)) diff --git a/python/copilot/tools.py b/python/copilot/tools.py index d9757820..43c1ed99 100644 --- a/python/copilot/tools.py +++ b/python/copilot/tools.py @@ -186,7 +186,7 @@ def _normalize_result(result: Any) -> ToolResult: # ToolResult passes through directly if isinstance(result, dict) and "resultType" in result and "textResultForLlm" in result: - return result # type: ignore + return result # Strings pass through directly if isinstance(result, str): diff --git a/python/copilot/types.py b/python/copilot/types.py index bb64dd98..14b8e65c 100644 --- a/python/copilot/types.py +++ b/python/copilot/types.py @@ -307,91 +307,363 @@ class MessageOptions(TypedDict): SessionEventHandler = Callable[[SessionEvent], None] +# Response from ping +@dataclass +class PingResponse: + """Response from ping""" + + message: str # Echo message with "pong: " prefix + timestamp: int # Server timestamp in milliseconds + protocolVersion: int # Protocol version for SDK compatibility + + @staticmethod + def from_dict(obj: Any) -> PingResponse: + assert isinstance(obj, dict) + message = obj.get("message") + timestamp = obj.get("timestamp") + protocolVersion = obj.get("protocolVersion") + if message is None or timestamp is None or protocolVersion is None: + raise ValueError( + f"Missing required fields in PingResponse: message={message}, " + f"timestamp={timestamp}, protocolVersion={protocolVersion}" + ) + return PingResponse(str(message), int(timestamp), int(protocolVersion)) + + def to_dict(self) -> dict: + result: dict = {} + result["message"] = self.message + result["timestamp"] = self.timestamp + result["protocolVersion"] = self.protocolVersion + return result + + +# Error information from client stop +@dataclass +class StopError: + """Error information from client stop""" + + message: str # Error message describing what failed during cleanup + + @staticmethod + def from_dict(obj: Any) -> StopError: + assert isinstance(obj, dict) + message = obj.get("message") + if message is None: + raise ValueError("Missing required field 'message' in StopError") + return StopError(str(message)) + + def to_dict(self) -> dict: + result: dict = {} + result["message"] = self.message + return result + + # Response from status.get -class GetStatusResponse(TypedDict): +@dataclass +class GetStatusResponse: """Response from status.get""" version: str # Package version (e.g., "1.0.0") protocolVersion: int # Protocol version for SDK compatibility + @staticmethod + def from_dict(obj: Any) -> GetStatusResponse: + assert isinstance(obj, dict) + version = obj.get("version") + protocolVersion = obj.get("protocolVersion") + if version is None or protocolVersion is None: + raise ValueError( + f"Missing required fields in GetStatusResponse: version={version}, " + f"protocolVersion={protocolVersion}" + ) + return GetStatusResponse(str(version), int(protocolVersion)) + + def to_dict(self) -> dict: + result: dict = {} + result["version"] = self.version + result["protocolVersion"] = self.protocolVersion + return result + # Response from auth.getStatus -class GetAuthStatusResponse(TypedDict): +@dataclass +class GetAuthStatusResponse: """Response from auth.getStatus""" isAuthenticated: bool # Whether the user is authenticated - authType: NotRequired[ - Literal["user", "env", "gh-cli", "hmac", "api-key", "token"] - ] # Authentication type - host: NotRequired[str] # GitHub host URL - login: NotRequired[str] # User login name - statusMessage: NotRequired[str] # Human-readable status message + authType: str | None = None # Authentication type + host: str | None = None # GitHub host URL + login: str | None = None # User login name + statusMessage: str | None = None # Human-readable status message + + @staticmethod + def from_dict(obj: Any) -> GetAuthStatusResponse: + assert isinstance(obj, dict) + isAuthenticated = obj.get("isAuthenticated") + if isAuthenticated is None: + raise ValueError("Missing required field 'isAuthenticated' in GetAuthStatusResponse") + authType = obj.get("authType") + host = obj.get("host") + login = obj.get("login") + statusMessage = obj.get("statusMessage") + return GetAuthStatusResponse( + isAuthenticated=bool(isAuthenticated), + authType=authType, + host=host, + login=login, + statusMessage=statusMessage, + ) + + def to_dict(self) -> dict: + result: dict = {} + result["isAuthenticated"] = self.isAuthenticated + if self.authType is not None: + result["authType"] = self.authType + if self.host is not None: + result["host"] = self.host + if self.login is not None: + result["login"] = self.login + if self.statusMessage is not None: + result["statusMessage"] = self.statusMessage + return result # Model capabilities -class ModelVisionLimits(TypedDict, total=False): +@dataclass +class ModelVisionLimits: """Vision-specific limits""" - supported_media_types: list[str] - max_prompt_images: int - max_prompt_image_size: int + supported_media_types: list[str] | None = None + max_prompt_images: int | None = None + max_prompt_image_size: int | None = None + + @staticmethod + def from_dict(obj: Any) -> ModelVisionLimits: + assert isinstance(obj, dict) + supported_media_types = obj.get("supported_media_types") + max_prompt_images = obj.get("max_prompt_images") + max_prompt_image_size = obj.get("max_prompt_image_size") + return ModelVisionLimits( + supported_media_types=supported_media_types, + max_prompt_images=max_prompt_images, + max_prompt_image_size=max_prompt_image_size, + ) + + def to_dict(self) -> dict: + result: dict = {} + if self.supported_media_types is not None: + result["supported_media_types"] = self.supported_media_types + if self.max_prompt_images is not None: + result["max_prompt_images"] = self.max_prompt_images + if self.max_prompt_image_size is not None: + result["max_prompt_image_size"] = self.max_prompt_image_size + return result -class ModelLimits(TypedDict, total=False): +@dataclass +class ModelLimits: """Model limits""" - max_prompt_tokens: int - max_context_window_tokens: int - vision: ModelVisionLimits + max_prompt_tokens: int | None = None + max_context_window_tokens: int | None = None + vision: ModelVisionLimits | None = None + + @staticmethod + def from_dict(obj: Any) -> ModelLimits: + assert isinstance(obj, dict) + max_prompt_tokens = obj.get("max_prompt_tokens") + max_context_window_tokens = obj.get("max_context_window_tokens") + vision_dict = obj.get("vision") + vision = ModelVisionLimits.from_dict(vision_dict) if vision_dict else None + return ModelLimits( + max_prompt_tokens=max_prompt_tokens, + max_context_window_tokens=max_context_window_tokens, + vision=vision, + ) + + def to_dict(self) -> dict: + result: dict = {} + if self.max_prompt_tokens is not None: + result["max_prompt_tokens"] = self.max_prompt_tokens + if self.max_context_window_tokens is not None: + result["max_context_window_tokens"] = self.max_context_window_tokens + if self.vision is not None: + result["vision"] = self.vision.to_dict() + return result -class ModelSupports(TypedDict): +@dataclass +class ModelSupports: """Model support flags""" vision: bool + @staticmethod + def from_dict(obj: Any) -> ModelSupports: + assert isinstance(obj, dict) + vision = obj.get("vision") + if vision is None: + raise ValueError("Missing required field 'vision' in ModelSupports") + return ModelSupports(vision=bool(vision)) -class ModelCapabilities(TypedDict): + def to_dict(self) -> dict: + result: dict = {} + result["vision"] = self.vision + return result + + +@dataclass +class ModelCapabilities: """Model capabilities and limits""" supports: ModelSupports limits: ModelLimits + @staticmethod + def from_dict(obj: Any) -> ModelCapabilities: + assert isinstance(obj, dict) + supports_dict = obj.get("supports") + limits_dict = obj.get("limits") + if supports_dict is None or limits_dict is None: + raise ValueError( + f"Missing required fields in ModelCapabilities: supports={supports_dict}, " + f"limits={limits_dict}" + ) + supports = ModelSupports.from_dict(supports_dict) + limits = ModelLimits.from_dict(limits_dict) + return ModelCapabilities(supports=supports, limits=limits) + + def to_dict(self) -> dict: + result: dict = {} + result["supports"] = self.supports.to_dict() + result["limits"] = self.limits.to_dict() + return result + -class ModelPolicy(TypedDict): +@dataclass +class ModelPolicy: """Model policy state""" - state: Literal["enabled", "disabled", "unconfigured"] + state: str # "enabled", "disabled", or "unconfigured" terms: str + @staticmethod + def from_dict(obj: Any) -> ModelPolicy: + assert isinstance(obj, dict) + state = obj.get("state") + terms = obj.get("terms") + if state is None or terms is None: + raise ValueError( + f"Missing required fields in ModelPolicy: state={state}, terms={terms}" + ) + return ModelPolicy(state=str(state), terms=str(terms)) + + def to_dict(self) -> dict: + result: dict = {} + result["state"] = self.state + result["terms"] = self.terms + return result -class ModelBilling(TypedDict): + +@dataclass +class ModelBilling: """Model billing information""" multiplier: float + @staticmethod + def from_dict(obj: Any) -> ModelBilling: + assert isinstance(obj, dict) + multiplier = obj.get("multiplier") + if multiplier is None: + raise ValueError("Missing required field 'multiplier' in ModelBilling") + return ModelBilling(multiplier=float(multiplier)) -class ModelInfo(TypedDict): + def to_dict(self) -> dict: + result: dict = {} + result["multiplier"] = self.multiplier + return result + + +@dataclass +class ModelInfo: """Information about an available model""" id: str # Model identifier (e.g., "claude-sonnet-4.5") name: str # Display name capabilities: ModelCapabilities # Model capabilities and limits - policy: NotRequired[ModelPolicy] # Policy state - billing: NotRequired[ModelBilling] # Billing information - + policy: ModelPolicy | None = None # Policy state + billing: ModelBilling | None = None # Billing information + + @staticmethod + def from_dict(obj: Any) -> ModelInfo: + assert isinstance(obj, dict) + id = obj.get("id") + name = obj.get("name") + capabilities_dict = obj.get("capabilities") + if id is None or name is None or capabilities_dict is None: + raise ValueError( + f"Missing required fields in ModelInfo: id={id}, name={name}, " + f"capabilities={capabilities_dict}" + ) + capabilities = ModelCapabilities.from_dict(capabilities_dict) + policy_dict = obj.get("policy") + policy = ModelPolicy.from_dict(policy_dict) if policy_dict else None + billing_dict = obj.get("billing") + billing = ModelBilling.from_dict(billing_dict) if billing_dict else None + return ModelInfo( + id=str(id), name=str(name), capabilities=capabilities, policy=policy, billing=billing + ) + + def to_dict(self) -> dict: + result: dict = {} + result["id"] = self.id + result["name"] = self.name + result["capabilities"] = self.capabilities.to_dict() + if self.policy is not None: + result["policy"] = self.policy.to_dict() + if self.billing is not None: + result["billing"] = self.billing.to_dict() + return result -class GetModelsResponse(TypedDict): - """Response from models.list""" - models: list[ModelInfo] - - -class SessionMetadata(TypedDict): +@dataclass +class SessionMetadata: """Metadata about a session""" sessionId: str # Session identifier startTime: str # ISO 8601 timestamp when session was created modifiedTime: str # ISO 8601 timestamp when session was last modified - summary: NotRequired[str] # Optional summary of the session isRemote: bool # Whether the session is remote + summary: str | None = None # Optional summary of the session + + @staticmethod + def from_dict(obj: Any) -> SessionMetadata: + assert isinstance(obj, dict) + sessionId = obj.get("sessionId") + startTime = obj.get("startTime") + modifiedTime = obj.get("modifiedTime") + isRemote = obj.get("isRemote") + if sessionId is None or startTime is None or modifiedTime is None or isRemote is None: + raise ValueError( + f"Missing required fields in SessionMetadata: sessionId={sessionId}, " + f"startTime={startTime}, modifiedTime={modifiedTime}, isRemote={isRemote}" + ) + summary = obj.get("summary") + return SessionMetadata( + sessionId=str(sessionId), + startTime=str(startTime), + modifiedTime=str(modifiedTime), + isRemote=bool(isRemote), + summary=summary, + ) + + def to_dict(self) -> dict: + result: dict = {} + result["sessionId"] = self.sessionId + result["startTime"] = self.startTime + result["modifiedTime"] = self.modifiedTime + result["isRemote"] = self.isRemote + if self.summary is not None: + result["summary"] = self.summary + return result diff --git a/python/e2e/test_client.py b/python/e2e/test_client.py index 5cb681ce..720ab416 100644 --- a/python/e2e/test_client.py +++ b/python/e2e/test_client.py @@ -17,8 +17,8 @@ async def test_should_start_and_connect_to_server_using_stdio(self): assert client.get_state() == "connected" pong = await client.ping("test message") - assert pong["message"] == "pong: test message" - assert pong["timestamp"] >= 0 + assert pong.message == "pong: test message" + assert pong.timestamp >= 0 errors = await client.stop() assert len(errors) == 0 @@ -35,8 +35,8 @@ async def test_should_start_and_connect_to_server_using_tcp(self): assert client.get_state() == "connected" pong = await client.ping("test message") - assert pong["message"] == "pong: test message" - assert pong["timestamp"] >= 0 + assert pong.message == "pong: test message" + assert pong.timestamp >= 0 errors = await client.stop() assert len(errors) == 0 @@ -61,7 +61,7 @@ async def test_should_return_errors_on_failed_cleanup(self): errors = await client.stop() assert len(errors) > 0 - assert "Failed to destroy session" in errors[0]["message"] + assert "Failed to destroy session" in errors[0].message finally: await client.force_stop() @@ -81,11 +81,11 @@ async def test_should_get_status_with_version_and_protocol_info(self): await client.start() status = await client.get_status() - assert "version" in status - assert isinstance(status["version"], str) - assert "protocolVersion" in status - assert isinstance(status["protocolVersion"], int) - assert status["protocolVersion"] >= 1 + assert hasattr(status, "version") + assert isinstance(status.version, str) + assert hasattr(status, "protocolVersion") + assert isinstance(status.protocolVersion, int) + assert status.protocolVersion >= 1 await client.stop() finally: @@ -99,11 +99,11 @@ async def test_should_get_auth_status(self): await client.start() auth_status = await client.get_auth_status() - assert "isAuthenticated" in auth_status - assert isinstance(auth_status["isAuthenticated"], bool) - if auth_status["isAuthenticated"]: - assert "authType" in auth_status - assert "statusMessage" in auth_status + assert hasattr(auth_status, "isAuthenticated") + assert isinstance(auth_status.isAuthenticated, bool) + if auth_status.isAuthenticated: + assert hasattr(auth_status, "authType") + assert hasattr(auth_status, "statusMessage") await client.stop() finally: @@ -117,7 +117,7 @@ async def test_should_list_models_when_authenticated(self): await client.start() auth_status = await client.get_auth_status() - if not auth_status["isAuthenticated"]: + if not auth_status.isAuthenticated: # Skip if not authenticated - models.list requires auth await client.stop() return @@ -126,11 +126,11 @@ async def test_should_list_models_when_authenticated(self): assert isinstance(models, list) if len(models) > 0: model = models[0] - assert "id" in model - assert "name" in model - assert "capabilities" in model - assert "supports" in model["capabilities"] - assert "limits" in model["capabilities"] + assert hasattr(model, "id") + assert hasattr(model, "name") + assert hasattr(model, "capabilities") + assert hasattr(model.capabilities, "supports") + assert hasattr(model.capabilities, "limits") await client.stop() finally: diff --git a/python/e2e/test_session.py b/python/e2e/test_session.py index 022548e5..3cd18852 100644 --- a/python/e2e/test_session.py +++ b/python/e2e/test_session.py @@ -196,21 +196,21 @@ async def test_should_list_sessions(self, ctx: E2ETestContext): sessions = await ctx.client.list_sessions() assert isinstance(sessions, list) - session_ids = [s["sessionId"] for s in sessions] + session_ids = [s.sessionId for s in sessions] assert session1.session_id in session_ids assert session2.session_id in session_ids # Verify session metadata structure for session_data in sessions: - assert "sessionId" in session_data - assert "startTime" in session_data - assert "modifiedTime" in session_data - assert "isRemote" in session_data + assert hasattr(session_data, "sessionId") + assert hasattr(session_data, "startTime") + assert hasattr(session_data, "modifiedTime") + assert hasattr(session_data, "isRemote") # summary is optional - assert isinstance(session_data["sessionId"], str) - assert isinstance(session_data["startTime"], str) - assert isinstance(session_data["modifiedTime"], str) - assert isinstance(session_data["isRemote"], bool) + assert isinstance(session_data.sessionId, str) + assert isinstance(session_data.startTime, str) + assert isinstance(session_data.modifiedTime, str) + assert isinstance(session_data.isRemote, bool) async def test_should_delete_session(self, ctx: E2ETestContext): import asyncio @@ -225,7 +225,7 @@ async def test_should_delete_session(self, ctx: E2ETestContext): # Verify session exists in the list sessions = await ctx.client.list_sessions() - session_ids = [s["sessionId"] for s in sessions] + session_ids = [s.sessionId for s in sessions] assert session_id in session_ids # Delete the session @@ -233,7 +233,7 @@ async def test_should_delete_session(self, ctx: E2ETestContext): # Verify session no longer exists in the list sessions_after = await ctx.client.list_sessions() - session_ids_after = [s["sessionId"] for s in sessions_after] + session_ids_after = [s.sessionId for s in sessions_after] assert session_id not in session_ids_after # Verify we cannot resume the deleted session