From bc9104799e22257f7b1f9b79f59e028b453d0756 Mon Sep 17 00:00:00 2001 From: Bartosz Burda Date: Mon, 30 Mar 2026 17:29:40 +0200 Subject: [PATCH 1/4] feat: migrate HTTP client to generated ros2-medkit-client Replace hand-written httpx client (~1078 lines) with a wrapper around the generated ros2-medkit-client package. SovdClient interface preserved - zero changes needed in mcp_app.py tool dispatcher. Key changes: - client.py: SovdClient delegates to MedkitClient via _call() with _ENTITY_FUNC_MAP dispatch for entity-type-polymorphic endpoints - Structured SOVD error codes preserved in SovdClientError messages - asyncio.Lock on lazy client init for concurrent safety - SSRF protection on get_bulk_data_info (was missing vs download) - URL-encoded path segments in raw request fallbacks - Extracted _httpx_client(), _extract_filename(), _validate_relative_uri() Also fixes: - update_execution dispatcher used args.request_data instead of args.update_data (pre-existing bug, AttributeError at runtime) - FaultItem model accepts fault_code alias from generated client - Fault formatting fallbacks check both camelCase and snake_case keys Tests updated: mock responses aligned with OpenAPI schema (snake_case field names, required fields). 111 tests pass. Closes selfpatch/ros2_medkit_mcp#11 --- poetry.lock | 54 +- pyproject.toml | 1 + src/ros2_medkit_mcp/client.py | 1287 ++++++++++++-------------------- src/ros2_medkit_mcp/mcp_app.py | 11 +- src/ros2_medkit_mcp/models.py | 6 +- tests/test_bulkdata_tools.py | 75 +- tests/test_mcp_app.py | 77 +- tests/test_tools.py | 190 +++-- 8 files changed, 746 insertions(+), 955 deletions(-) diff --git a/poetry.lock b/poetry.lock index f5f6643..1379021 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.3.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.3.2 and should not be changed by hand. [[package]] name = "annotated-types" @@ -1054,6 +1054,21 @@ pytest = ">=8.2,<9" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +description = "Extensions to the standard Python datetime module" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] +files = [ + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, +] + +[package.dependencies] +six = ">=1.5" + [[package]] name = "python-dotenv" version = "1.2.1" @@ -1227,6 +1242,29 @@ files = [ [package.dependencies] httpx = ">=0.25.0" +[[package]] +name = "ros2-medkit-client" +version = "0.1.0" +description = "Async Python client for the ros2_medkit gateway" +optional = false +python-versions = ">=3.11" +groups = ["main"] +files = [ + {file = "ros2_medkit_client-0.1.0-py3-none-any.whl", hash = "sha256:49c60c0070f90d6272ebd16236a5fb0acce470e4cdf5f5fb811ac14f6d483c14"}, +] + +[package.dependencies] +attrs = ">=23.0" +httpx = ">=0.27" +python-dateutil = ">=2.9" + +[package.extras] +dev = ["pytest (>=8.0)", "pytest-asyncio (>=0.24)", "respx (>=0.22)", "ruff (>=0.8)"] + +[package.source] +type = "url" +url = "https://github.com/selfpatch/ros2_medkit_clients/releases/download/py-v0.1.0/ros2_medkit_client-0.1.0-py3-none-any.whl" + [[package]] name = "rpds-py" version = "0.30.0" @@ -1380,6 +1418,18 @@ files = [ {file = "ruff-0.8.6.tar.gz", hash = "sha256:dcad24b81b62650b0eb8814f576fc65cfee8674772a6e24c9b747911801eeaa5"}, ] +[[package]] +name = "six" +version = "1.17.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] +files = [ + {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, + {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, +] + [[package]] name = "sse-starlette" version = "3.0.3" @@ -1762,4 +1812,4 @@ files = [ [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "21f486df2aae446471abbc175bbbda63cace44ac20f0185855e54942e0c10925" +content-hash = "be551e8971b9e75cf16c89c858b8fb9c878e1f88903348d4741d196ecfec0861" diff --git a/pyproject.toml b/pyproject.toml index 939a9c1..312d18d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ httpx = "^0.28.0" pydantic = "^2.10.0" uvicorn = { version = "^0.34.0", extras = ["standard"] } starlette = "^0.45.0" +ros2-medkit-client = {url = "https://github.com/selfpatch/ros2_medkit_clients/releases/download/py-v0.1.0/ros2_medkit_client-0.1.0-py3-none-any.whl"} [tool.poetry.group.dev.dependencies] pytest = "^8.3.0" diff --git a/src/ros2_medkit_mcp/client.py b/src/ros2_medkit_mcp/client.py index 8a4bbad..9ce55b1 100644 --- a/src/ros2_medkit_mcp/client.py +++ b/src/ros2_medkit_mcp/client.py @@ -1,15 +1,32 @@ """HTTP client wrapper for ros2_medkit SOVD API. -Provides async HTTP client with proper lifecycle management, -authentication, and error handling. +Delegates to the generated ros2-medkit-client package (MedkitClient) +while preserving the SovdClient interface used by mcp_app.py. + +Note: response dicts use snake_case field names from the generated +client's to_dict() method (e.g. fault_code, environment_data), +not the gateway's camelCase JSON. """ +import asyncio import logging +import re from collections.abc import AsyncIterator -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, suppress from typing import Any +from urllib.parse import quote import httpx +from ros2_medkit_client import MedkitClient, MedkitError +from ros2_medkit_client.api import ( + bulk_data, + configuration, + data, + discovery, + faults, + operations, + server, +) from ros2_medkit_mcp.config import Settings @@ -30,653 +47,443 @@ def __init__( self.request_id = request_id +def _to_dict(obj: Any) -> Any: + """Convert a generated model object to a dict, or pass through if already a dict/list.""" + if obj is None: + return {} + if isinstance(obj, dict | str | int | float | bool): + return obj + if isinstance(obj, list): + return [_to_dict(item) for item in obj] + if hasattr(obj, "to_dict"): + return obj.to_dict() + return obj + + +def _extract_items(result: Any) -> list[dict[str, Any]]: + """Extract items list from a collection response.""" + d = _to_dict(result) + if isinstance(d, list): + return d + if isinstance(d, dict): + for key in ( + "items", + "areas", + "components", + "apps", + "functions", + "faults", + "configurations", + ): + if key in d: + return d[key] + return [d] if d else [] + + +def _extract_filename(content_disposition: str) -> str | None: + """Extract filename from Content-Disposition header.""" + if "filename=" not in content_disposition: + return None + match = re.search(r'filename="?([^"]+)"?', content_disposition) + return match.group(1) if match else None + + +def _validate_relative_uri(uri: str) -> None: + """Reject absolute URLs to prevent SSRF.""" + if uri.startswith(("http://", "https://", "//")): + raise ValueError(f"Absolute URLs not allowed: {uri}") + + +# Mapping of (entity_type, resource, method) -> generated function name +# entity_type is plural ("components"), singular is derived by stripping trailing "s" +_ENTITY_FUNC_MAP: dict[str, dict[str, dict[str, Any]]] = { + "faults": { + "list": { + "components": faults.list_component_faults, + "apps": faults.list_app_faults, + "areas": faults.list_area_faults, + "functions": faults.list_function_faults, + }, + "get": { + "components": faults.get_component_fault, + "apps": faults.get_app_fault, + "areas": faults.get_area_fault, + "functions": faults.get_function_fault, + }, + "clear": { + "components": faults.clear_component_fault, + "apps": faults.clear_app_fault, + "areas": faults.clear_area_fault, + "functions": faults.clear_function_fault, + }, + "clear_all": { + "components": faults.clear_all_component_faults, + "apps": faults.clear_all_app_faults, + "areas": faults.clear_all_area_faults, + "functions": faults.clear_all_function_faults, + }, + }, + "data": { + "list": { + "components": data.list_component_data, + "apps": data.list_app_data, + "areas": data.list_area_data, + "functions": data.list_function_data, + }, + "get": { + "components": data.get_component_data_item, + "apps": data.get_app_data_item, + "areas": data.get_area_data_item, + "functions": data.get_function_data_item, + }, + "put": { + "components": data.put_component_data_item, + "apps": data.put_app_data_item, + "areas": data.put_area_data_item, + "functions": data.put_function_data_item, + }, + }, + "operations": { + "list": { + "components": operations.list_component_operations, + "apps": operations.list_app_operations, + "areas": operations.list_area_operations, + "functions": operations.list_function_operations, + }, + "get": { + "components": operations.get_component_operation, + "apps": operations.get_app_operation, + "areas": operations.get_area_operation, + "functions": operations.get_function_operation, + }, + "execute": { + "components": operations.execute_component_operation, + "apps": operations.execute_app_operation, + "areas": operations.execute_area_operation, + "functions": operations.execute_function_operation, + }, + "list_executions": { + "components": operations.list_component_executions, + "apps": operations.list_app_executions, + "areas": operations.list_area_executions, + "functions": operations.list_function_executions, + }, + "get_execution": { + "components": operations.get_component_execution, + "apps": operations.get_app_execution, + "areas": operations.get_area_execution, + "functions": operations.get_function_execution, + }, + "update_execution": { + "components": operations.update_component_execution, + "apps": operations.update_app_execution, + "areas": operations.update_area_execution, + "functions": operations.update_function_execution, + }, + "cancel_execution": { + "components": operations.cancel_component_execution, + "apps": operations.cancel_app_execution, + "areas": operations.cancel_area_execution, + "functions": operations.cancel_function_execution, + }, + }, + "configurations": { + "list": { + "components": configuration.list_component_configurations, + "apps": configuration.list_app_configurations, + "areas": configuration.list_area_configurations, + "functions": configuration.list_function_configurations, + }, + "get": { + "components": configuration.get_component_configuration, + "apps": configuration.get_app_configuration, + "areas": configuration.get_area_configuration, + "functions": configuration.get_function_configuration, + }, + "set": { + "components": configuration.set_component_configuration, + "apps": configuration.set_app_configuration, + "areas": configuration.set_area_configuration, + "functions": configuration.set_function_configuration, + }, + "delete": { + "components": configuration.delete_component_configuration, + "apps": configuration.delete_app_configuration, + "areas": configuration.delete_area_configuration, + "functions": configuration.delete_function_configuration, + }, + "delete_all": { + "components": configuration.delete_all_component_configurations, + "apps": configuration.delete_all_app_configurations, + "areas": configuration.delete_all_area_configurations, + "functions": configuration.delete_all_function_configurations, + }, + }, + "bulk_data": { + "list_categories": { + "components": bulk_data.list_component_bulk_data_categories, + "apps": bulk_data.list_app_bulk_data_categories, + "areas": bulk_data.list_area_bulk_data_categories, + "functions": bulk_data.list_function_bulk_data_categories, + }, + "list": { + "components": bulk_data.list_component_bulk_data_descriptors, + "apps": bulk_data.list_app_bulk_data_descriptors, + "areas": bulk_data.list_area_bulk_data_descriptors, + "functions": bulk_data.list_function_bulk_data_descriptors, + }, + }, +} + + +def _entity_func(resource: str, method: str, entity_type: str) -> Any: + """Look up the generated API function for a resource/method/entity_type combo.""" + resource_map = _ENTITY_FUNC_MAP.get(resource) + if not resource_map: + raise SovdClientError(f"Unknown resource: {resource}") + method_map = resource_map.get(method) + if not method_map: + raise SovdClientError(f"Unknown method {method} for {resource}") + func = method_map.get(entity_type) + if not func: + raise SovdClientError(f"No API function for {entity_type}/{resource}/{method}") + return func.asyncio + + +def _entity_id_kwarg(entity_type: str) -> str: + """Get the keyword argument name for entity ID based on type.""" + return f"{entity_type.removesuffix('s')}_id" + + class SovdClient: """Async HTTP client for ros2_medkit SOVD API. - Manages HTTP connection lifecycle and provides typed methods - for each API endpoint. + Wraps the generated MedkitClient while preserving the interface + expected by mcp_app.py. """ def __init__(self, settings: Settings) -> None: - """Initialize the client with settings. - - Args: - settings: Application settings containing base URL, auth, timeout. - """ self._settings = settings - self._client: httpx.AsyncClient | None = None - - def _build_headers(self) -> dict[str, str]: - """Build HTTP headers including authentication if configured. - - Returns: - Dictionary of HTTP headers. - """ - headers: dict[str, str] = { - "Accept": "application/json", - "User-Agent": "ros2_medkit_mcp/0.1.0", - } - if self._settings.bearer_token: - headers["Authorization"] = f"Bearer {self._settings.bearer_token}" - return headers + self._medkit: MedkitClient | None = None + self._entered = False + self._init_lock = asyncio.Lock() + + async def _ensure_client(self) -> MedkitClient: + if self._medkit is not None: + return self._medkit + async with self._init_lock: + if self._medkit is None: + self._medkit = MedkitClient( + base_url=self._settings.base_url, + auth_token=self._settings.bearer_token, + timeout=self._settings.timeout_seconds, + ) + await self._medkit.__aenter__() + self._entered = True + return self._medkit - async def _ensure_client(self) -> httpx.AsyncClient: - """Ensure HTTP client is initialized. + async def _httpx_client(self) -> httpx.AsyncClient: + """Get the underlying httpx client for raw requests. - Returns: - The initialized async HTTP client. + Used for endpoints not covered by the generated client + (fault snapshots, bulk-data HEAD/download). """ - if self._client is None: - self._client = httpx.AsyncClient( - base_url=self._settings.base_url.rstrip("/"), - headers=self._build_headers(), - timeout=httpx.Timeout(self._settings.timeout_seconds), - ) - return self._client + client = await self._ensure_client() + return client.http.get_async_httpx_client() async def close(self) -> None: - """Close the HTTP client and release resources.""" - if self._client is not None: - await self._client.aclose() - self._client = None - - def _extract_request_id(self, response: httpx.Response) -> str | None: - """Extract request ID from response headers. - - Args: - response: HTTP response to inspect. - - Returns: - Request ID if present, None otherwise. - """ - for header in ("X-Request-ID", "X-Request-Id", "Request-Id", "request-id"): - if header in response.headers: - return response.headers[header] - return None - - def _log_response( - self, - method: str, - path: str, - response: httpx.Response, - request_id: str | None, - ) -> None: - """Log HTTP response details. - - Args: - method: HTTP method used. - path: Request path. - response: HTTP response received. - request_id: Request ID if present. - """ - log_extra = {"status": response.status_code, "method": method, "path": path} - if request_id: - log_extra["request_id"] = request_id - - if response.is_success: - logger.debug("HTTP request succeeded", extra=log_extra) - else: - logger.warning("HTTP request failed", extra=log_extra) + if self._medkit is not None and self._entered: + await self._medkit.__aexit__(None, None, None) + self._medkit = None + self._entered = False - async def _request( - self, - method: str, - path: str, - params: dict[str, Any] | None = None, - json_body: dict[str, Any] | None = None, - ) -> Any: - """Make an HTTP request and return JSON response. - - Args: - method: HTTP method (GET, POST, etc.). - path: API endpoint path. - params: Optional query parameters. - json_body: Optional JSON body for POST/PUT requests. - - Returns: - Parsed JSON response. - - Raises: - SovdClientError: If request fails or returns non-2xx status. - """ + async def _call(self, api_func: Any, **kwargs: Any) -> Any: + """Call a generated API function, converting errors to SovdClientError.""" client = await self._ensure_client() - try: - response = await client.request(method, path, params=params, json=json_body) - request_id = self._extract_request_id(response) - self._log_response(method, path, response, request_id) + result = await client.call(api_func, **kwargs) + return _to_dict(result) + except MedkitError as e: + msg = f"[{e.code}] {e.message}" if e.code else str(e) + raise SovdClientError(message=msg, status_code=e.status) from e + except httpx.TimeoutException as e: + raise SovdClientError(message=f"Request timed out: {e}") from e + except httpx.RequestError as e: + raise SovdClientError(message=f"Request failed: {e}") from e + async def _raw_request(self, method: str, path: str) -> Any: + """Make a raw HTTP request for endpoints not in the generated client + (fault snapshots). Path segments must be pre-encoded by the caller.""" + try: + hc = await self._httpx_client() + response = await hc.request(method, path) if not response.is_success: - error_msg = f"HTTP {response.status_code}: {response.text[:200]}" raise SovdClientError( - message=error_msg, + message=f"Gateway returned HTTP {response.status_code}", status_code=response.status_code, - request_id=request_id, ) - - try: - return response.json() - except ValueError as e: - error_msg = "Invalid JSON in response body" - raise SovdClientError( - message=error_msg, - status_code=response.status_code, - request_id=request_id, - ) from e - + return response.json() except httpx.RequestError as e: - logger.error("HTTP request error: %s", e) raise SovdClientError(message=f"Request failed: {e}") from e - async def get_version(self) -> dict[str, Any]: - """Get SOVD API version information. + # ==================== Server ==================== - Returns: - Version information as dictionary. - """ - return await self._request("GET", "/version-info") + async def get_version(self) -> dict[str, Any]: + return await self._call(server.get_version_info.asyncio) async def get_health(self) -> dict[str, Any]: - """Get health status of the gateway. + return await self._call(server.get_health.asyncio) - Returns: - Health status as dictionary. - """ - return await self._request("GET", "/health") + # ==================== Discovery ==================== async def list_entities(self) -> list[dict[str, Any]]: - """List all SOVD entities (areas, components, apps, and functions combined). - - Returns: - List of entity dictionaries. - """ - entities = [] - - # Fetch areas - try: - areas_result = await self._request("GET", "/areas") - if isinstance(areas_result, list): - entities.extend(areas_result) - elif isinstance(areas_result, dict): - if "areas" in areas_result: - entities.extend(areas_result["areas"]) - elif "items" in areas_result: - entities.extend(areas_result["items"]) - except SovdClientError: - pass # Skip if areas endpoint fails - - # Fetch components - try: - components_result = await self._request("GET", "/components") - if isinstance(components_result, list): - entities.extend(components_result) - elif isinstance(components_result, dict): - if "components" in components_result: - entities.extend(components_result["components"]) - elif "items" in components_result: - entities.extend(components_result["items"]) - except SovdClientError: - pass # Skip if components endpoint fails - - # Fetch apps - try: - apps_result = await self._request("GET", "/apps") - if isinstance(apps_result, list): - entities.extend(apps_result) - elif isinstance(apps_result, dict): - if "apps" in apps_result: - entities.extend(apps_result["apps"]) - elif "items" in apps_result: - entities.extend(apps_result["items"]) - except SovdClientError: - pass # Skip if apps endpoint fails - - # Fetch functions - try: - functions_result = await self._request("GET", "/functions") - if isinstance(functions_result, list): - entities.extend(functions_result) - elif isinstance(functions_result, dict): - if "functions" in functions_result: - entities.extend(functions_result["functions"]) - elif "items" in functions_result: - entities.extend(functions_result["items"]) - except SovdClientError: - pass # Skip if functions endpoint fails - + entities: list[dict[str, Any]] = [] + for list_fn in (self.list_areas, self.list_components, self.list_apps, self.list_functions): + with suppress(SovdClientError): + entities.extend(await list_fn()) return entities async def list_areas(self) -> list[dict[str, Any]]: - """List all SOVD areas. - - Returns: - List of area dictionaries. - """ - result = await self._request("GET", "/areas") - if isinstance(result, list): - return result - if isinstance(result, dict) and "areas" in result: - return result["areas"] - return [result] if result else [] + return _extract_items(await self._call(discovery.list_areas.asyncio)) async def get_area(self, area_id: str) -> dict[str, Any]: - """Get details of a specific area. - - Args: - area_id: The area identifier. - - Returns: - Area data dictionary. - """ - result = await self._request("GET", f"/areas/{area_id}") - return result.get("item", result) if isinstance(result, dict) else result + return await self._call(discovery.get_area.asyncio, area_id=area_id) async def list_components(self) -> list[dict[str, Any]]: - """List all SOVD components. - - Returns: - List of component dictionaries. - """ - result = await self._request("GET", "/components") - if isinstance(result, list): - return result - if isinstance(result, dict): - if "components" in result: - return result["components"] - if "items" in result: - return result["items"] - return [result] if result else [] + return _extract_items(await self._call(discovery.list_components.asyncio)) async def get_component(self, component_id: str) -> dict[str, Any]: - """Get details of a specific component. - - Args: - component_id: The component identifier. - - Returns: - Component data dictionary. - """ - result = await self._request("GET", f"/components/{component_id}") - return result.get("item", result) if isinstance(result, dict) else result + return await self._call(discovery.get_component.asyncio, component_id=component_id) async def list_apps(self) -> list[dict[str, Any]]: - """List all SOVD apps (ROS 2 nodes). - - Returns: - List of app dictionaries. - """ - result = await self._request("GET", "/apps") - if isinstance(result, list): - return result - if isinstance(result, dict): - if "apps" in result: - return result["apps"] - if "items" in result: - return result["items"] - return [result] if result else [] + return _extract_items(await self._call(discovery.list_apps.asyncio)) async def get_app(self, app_id: str) -> dict[str, Any]: - """Get app capabilities and details. - - Args: - app_id: The app identifier. - - Returns: - App data dictionary. - """ - result = await self._request("GET", f"/apps/{app_id}") - return result.get("item", result) if isinstance(result, dict) else result + return await self._call(discovery.get_app.asyncio, app_id=app_id) async def list_app_dependencies(self, app_id: str) -> list[dict[str, Any]]: - """List dependencies for an app. - - Args: - app_id: The app identifier. - - Returns: - List of dependency dictionaries. - """ - result = await self._request("GET", f"/apps/{app_id}/depends-on") - if isinstance(result, list): - return result - if isinstance(result, dict) and "items" in result: - return result["items"] - return [result] if result else [] + return _extract_items( + await self._call(discovery.list_app_dependencies.asyncio, app_id=app_id) + ) async def list_functions(self) -> list[dict[str, Any]]: - """List all SOVD functions. - - Returns: - List of function dictionaries. - """ - result = await self._request("GET", "/functions") - if isinstance(result, list): - return result - if isinstance(result, dict): - if "functions" in result: - return result["functions"] - if "items" in result: - return result["items"] - return [result] if result else [] + return _extract_items(await self._call(discovery.list_functions.asyncio)) async def get_function(self, function_id: str) -> dict[str, Any]: - """Get function details. - - Args: - function_id: The function identifier. - - Returns: - Function data dictionary. - """ - result = await self._request("GET", f"/functions/{function_id}") - return result.get("item", result) if isinstance(result, dict) else result + return await self._call(discovery.get_function.asyncio, function_id=function_id) async def list_function_hosts(self, function_id: str) -> list[dict[str, Any]]: - """List apps that host a function. + return _extract_items( + await self._call(discovery.list_function_hosts.asyncio, function_id=function_id) + ) - Args: - function_id: The function identifier. + async def get_entity(self, entity_id: str) -> dict[str, Any]: + entities = await self.list_entities() + for entity in entities: + if entity.get("id") == entity_id: + if entity.get("type") == "Component": + try: + component_data = await self.get_component_data(entity_id) + return {**entity, "data": component_data} + except SovdClientError: + pass + return entity + raise SovdClientError(message=f"Entity '{entity_id}' not found", status_code=404) - Returns: - List of host app dictionaries. - """ - result = await self._request("GET", f"/functions/{function_id}/hosts") - if isinstance(result, list): - return result - if isinstance(result, dict) and "items" in result: - return result["items"] - return [result] if result else [] + # ==================== Area Relationships ==================== - async def get_entity(self, entity_id: str) -> dict[str, Any]: - """Get a specific entity by ID. + async def list_area_components(self, area_id: str) -> list[dict[str, Any]]: + return _extract_items( + await self._call(discovery.list_area_components.asyncio, area_id=area_id) + ) - Searches through the list of all entities to find the one with matching ID. - For components, also attempts to fetch live data. + async def list_area_subareas(self, area_id: str) -> list[dict[str, Any]]: + return _extract_items(await self._call(discovery.list_subareas.asyncio, area_id=area_id)) - Args: - entity_id: The entity identifier (component or area ID). + async def list_area_contains(self, area_id: str) -> list[dict[str, Any]]: + return _extract_items( + await self._call(discovery.list_area_contains.asyncio, area_id=area_id) + ) - Returns: - Entity data as dictionary. + # ==================== Component Relationships ==================== - Raises: - SovdClientError: If entity not found. - """ - entities = await self.list_entities() + async def list_component_subcomponents(self, component_id: str) -> list[dict[str, Any]]: + return _extract_items( + await self._call(discovery.list_subcomponents.asyncio, component_id=component_id) + ) - # Find the entity with matching ID - found_entity = None - for entity in entities: - if entity.get("id") == entity_id: - found_entity = entity - break + async def list_component_hosts(self, component_id: str) -> list[dict[str, Any]]: + return _extract_items( + await self._call(discovery.list_component_hosts.asyncio, component_id=component_id) + ) - if not found_entity: - raise SovdClientError( - message=f"Entity '{entity_id}' not found", - status_code=404, + async def list_component_dependencies(self, component_id: str) -> list[dict[str, Any]]: + return _extract_items( + await self._call( + discovery.list_component_dependencies.asyncio, component_id=component_id ) + ) - # If it's a component, try to fetch its live data - if found_entity.get("type") == "Component": - try: - # Use the fqn (fully qualified name) to fetch the component's namespace path - fqn = found_entity.get("fqn", "").lstrip("/") - if fqn: - # Try to get component data - might need to use area path - component_data = await self._request("GET", f"/components/{entity_id}/data") - return {**found_entity, "data": component_data} - except SovdClientError: - pass # Component doesn't expose data endpoint - - return found_entity + # ==================== Faults ==================== async def list_faults( self, entity_id: str, entity_type: str = "components" ) -> list[dict[str, Any]]: - """List all faults for an entity. - - Args: - entity_id: The entity identifier. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - List of fault dictionaries. - """ - result = await self._request("GET", f"/{entity_type}/{entity_id}/faults") - if isinstance(result, list): - return result - # Handle case where response is wrapped in an object - if isinstance(result, dict): - if "faults" in result: - return result["faults"] - if "items" in result: - return result["items"] - return [result] if result else [] + fn = _entity_func("faults", "list", entity_type) + return _extract_items(await self._call(fn, **{_entity_id_kwarg(entity_type): entity_id})) async def get_fault( self, entity_id: str, fault_id: str, entity_type: str = "components" ) -> dict[str, Any]: - """Get a specific fault by ID. - - Args: - entity_id: The entity identifier. - fault_id: The fault identifier (fault code). - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - Fault data dictionary. - """ - return await self._request("GET", f"/{entity_type}/{entity_id}/faults/{fault_id}") + fn = _entity_func("faults", "get", entity_type) + return await self._call( + fn, **{_entity_id_kwarg(entity_type): entity_id, "fault_code": fault_id} + ) async def clear_fault( self, entity_id: str, fault_id: str, entity_type: str = "components" ) -> dict[str, Any]: - """Clear (acknowledge/dismiss) a fault. - - Args: - entity_id: The entity identifier. - fault_id: The fault identifier to clear. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - Response dictionary with clear status. - """ - return await self._request("DELETE", f"/{entity_type}/{entity_id}/faults/{fault_id}") + fn = _entity_func("faults", "clear", entity_type) + return await self._call( + fn, **{_entity_id_kwarg(entity_type): entity_id, "fault_code": fault_id} + ) async def clear_all_faults( self, entity_id: str, entity_type: str = "components" ) -> dict[str, Any]: - """Clear all faults for an entity. - - Args: - entity_id: The entity identifier. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - Response dictionary with clear status. - """ - return await self._request("DELETE", f"/{entity_type}/{entity_id}/faults") + fn = _entity_func("faults", "clear_all", entity_type) + return await self._call(fn, **{_entity_id_kwarg(entity_type): entity_id}) async def list_all_faults(self) -> list[dict[str, Any]]: - """List all faults across the entire system. - - Returns: - List of all fault dictionaries. - """ - result = await self._request("GET", "/faults") - if isinstance(result, list): - return result - if isinstance(result, dict): - if "faults" in result: - return result["faults"] - if "items" in result: - return result["items"] - return [result] if result else [] + return _extract_items(await self._call(faults.list_all_faults.asyncio)) async def get_fault_snapshots( self, entity_id: str, fault_code: str, entity_type: str = "components" ) -> dict[str, Any]: - """Get diagnostic snapshots for a fault. - - Args: - entity_id: The entity identifier. - fault_code: The fault code. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - Snapshot data dictionary. - """ - return await self._request( - "GET", f"/{entity_type}/{entity_id}/faults/{fault_code}/snapshots" + return await self._raw_request( + "GET", + f"/{quote(entity_type, safe='')}/{quote(entity_id, safe='')}" + f"/faults/{quote(fault_code, safe='')}/snapshots", ) async def get_system_fault_snapshots(self, fault_code: str) -> dict[str, Any]: - """Get system-wide diagnostic snapshots for a fault. - - Args: - fault_code: The fault code. - - Returns: - Snapshot data dictionary. - """ - return await self._request("GET", f"/faults/{fault_code}/snapshots") - - # ==================== Area Components ==================== - - async def list_area_components(self, area_id: str) -> list[dict[str, Any]]: - """List all components within a specific area. - - Args: - area_id: The area identifier (e.g., 'powertrain', 'chassis', 'body'). + return await self._raw_request("GET", f"/faults/{quote(fault_code, safe='')}/snapshots") - Returns: - List of component dictionaries. - """ - result = await self._request("GET", f"/areas/{area_id}/components") - if isinstance(result, list): - return result - if isinstance(result, dict) and "items" in result: - return result["items"] - return [result] if result else [] - - async def list_area_subareas(self, area_id: str) -> list[dict[str, Any]]: - """List sub-areas within an area. - - Args: - area_id: The area identifier. - - Returns: - List of sub-area dictionaries. - """ - result = await self._request("GET", f"/areas/{area_id}/subareas") - if isinstance(result, list): - return result - if isinstance(result, dict) and "items" in result: - return result["items"] - return [result] if result else [] - - async def list_area_contains(self, area_id: str) -> list[dict[str, Any]]: - """List all entities contained in an area. - - Args: - area_id: The area identifier. - - Returns: - List of contained entity dictionaries. - """ - result = await self._request("GET", f"/areas/{area_id}/contains") - if isinstance(result, list): - return result - if isinstance(result, dict) and "items" in result: - return result["items"] - return [result] if result else [] - - # ==================== Component Relationships ==================== - - async def list_component_subcomponents(self, component_id: str) -> list[dict[str, Any]]: - """List subcomponents of a component. - - Args: - component_id: The component identifier. - - Returns: - List of subcomponent dictionaries. - """ - result = await self._request("GET", f"/components/{component_id}/subcomponents") - if isinstance(result, list): - return result - if isinstance(result, dict) and "items" in result: - return result["items"] - return [result] if result else [] - - async def list_component_hosts(self, component_id: str) -> list[dict[str, Any]]: - """List apps hosted by a component. - - Args: - component_id: The component identifier. - - Returns: - List of hosted app dictionaries. - """ - result = await self._request("GET", f"/components/{component_id}/hosts") - if isinstance(result, list): - return result - if isinstance(result, dict) and "items" in result: - return result["items"] - return [result] if result else [] - - async def list_component_dependencies(self, component_id: str) -> list[dict[str, Any]]: - """List dependencies of a component. - - Args: - component_id: The component identifier. - - Returns: - List of dependency dictionaries. - """ - result = await self._request("GET", f"/components/{component_id}/depends-on") - if isinstance(result, list): - return result - if isinstance(result, dict) and "items" in result: - return result["items"] - return [result] if result else [] - - # ==================== Component Data ==================== + # ==================== Data ==================== async def get_component_data( self, entity_id: str, entity_type: str = "components" ) -> list[dict[str, Any]]: - """Read all topic data from an entity. - - Args: - entity_id: The entity identifier. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - List of topic data dictionaries. - """ - result = await self._request("GET", f"/{entity_type}/{entity_id}/data") - if isinstance(result, list): - return result - if isinstance(result, dict) and "items" in result: - return result["items"] - return [result] if result else [] + fn = _entity_func("data", "list", entity_type) + return _extract_items(await self._call(fn, **{_entity_id_kwarg(entity_type): entity_id})) async def get_component_topic_data( self, entity_id: str, topic_name: str, entity_type: str = "components" ) -> dict[str, Any]: - """Read data from a specific topic within an entity. - - Args: - entity_id: The entity identifier. - topic_name: The topic name (e.g., 'temperature', 'rpm'). - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - Topic data dictionary. - """ - return await self._request("GET", f"/{entity_type}/{entity_id}/data/{topic_name}") + fn = _entity_func("data", "get", entity_type) + return await self._call( + fn, **{_entity_id_kwarg(entity_type): entity_id, "data_id": topic_name} + ) async def publish_to_topic( self, @@ -685,56 +492,28 @@ async def publish_to_topic( data: dict[str, Any], entity_type: str = "components", ) -> dict[str, Any]: - """Publish data to an entity's topic. - - Args: - entity_id: The entity identifier. - topic_name: The topic name to publish to. - data: The message data to publish. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - Response dictionary with publish status. - """ - return await self._request( - "PUT", f"/{entity_type}/{entity_id}/data/{topic_name}", json_body=data + fn = _entity_func("data", "put", entity_type) + return await self._call( + fn, + **{_entity_id_kwarg(entity_type): entity_id, "data_id": topic_name, "body": data}, ) - # ==================== Operations (Services & Actions) ==================== + # ==================== Operations ==================== async def list_operations( self, entity_id: str, entity_type: str = "components" ) -> list[dict[str, Any]]: - """List all operations (services and actions) for an entity. - - Args: - entity_id: The entity identifier. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - List of operation dictionaries. - """ - result = await self._request("GET", f"/{entity_type}/{entity_id}/operations") - if isinstance(result, list): - return result - if isinstance(result, dict) and "items" in result: - return result["items"] - return [result] if result else [] + fn = _entity_func("operations", "list", entity_type) + return _extract_items(await self._call(fn, **{_entity_id_kwarg(entity_type): entity_id})) async def get_operation( self, entity_id: str, operation_name: str, entity_type: str = "components" ) -> dict[str, Any]: - """Get details of a specific operation. - - Args: - entity_id: The entity identifier. - operation_name: The operation name. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - Operation details dictionary. - """ - return await self._request("GET", f"/{entity_type}/{entity_id}/operations/{operation_name}") + fn = _entity_func("operations", "get", entity_type) + return await self._call( + fn, + **{_entity_id_kwarg(entity_type): entity_id, "operation_id": operation_name}, + ) async def create_execution( self, @@ -743,47 +522,25 @@ async def create_execution( request_data: dict[str, Any] | None = None, entity_type: str = "components", ) -> dict[str, Any]: - """Start an execution for an operation (service call or action goal). - - Args: - entity_id: The entity identifier. - operation_name: The operation name (service or action). - request_data: Optional request data (goal for actions, request for services). - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - Response dictionary with execution_id for actions, or result for services. - """ - body: dict[str, Any] = {} + fn = _entity_func("operations", "execute", entity_type) + kwargs: dict[str, Any] = { + _entity_id_kwarg(entity_type): entity_id, + "operation_id": operation_name, + } if request_data: - body["parameters"] = request_data # SOVD uses 'parameters' field - return await self._request( - "POST", - f"/{entity_type}/{entity_id}/operations/{operation_name}/executions", - json_body=body if body else None, - ) + kwargs["body"] = {"parameters": request_data} + return await self._call(fn, **kwargs) async def list_executions( self, entity_id: str, operation_name: str, entity_type: str = "components" ) -> list[dict[str, Any]]: - """List all executions for an operation. - - Args: - entity_id: The entity identifier. - operation_name: The operation name. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - List of execution dictionaries. - """ - result = await self._request( - "GET", f"/{entity_type}/{entity_id}/operations/{operation_name}/executions" + fn = _entity_func("operations", "list_executions", entity_type) + return _extract_items( + await self._call( + fn, + **{_entity_id_kwarg(entity_type): entity_id, "operation_id": operation_name}, + ) ) - if isinstance(result, list): - return result - if isinstance(result, dict) and "items" in result: - return result["items"] - return [result] if result else [] async def get_execution( self, @@ -792,20 +549,14 @@ async def get_execution( execution_id: str, entity_type: str = "components", ) -> dict[str, Any]: - """Get execution status and feedback. - - Args: - entity_id: The entity identifier. - operation_name: The operation name. - execution_id: The execution identifier. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - Execution status dictionary. - """ - return await self._request( - "GET", - f"/{entity_type}/{entity_id}/operations/{operation_name}/executions/{execution_id}", + fn = _entity_func("operations", "get_execution", entity_type) + return await self._call( + fn, + **{ + _entity_id_kwarg(entity_type): entity_id, + "operation_id": operation_name, + "execution_id": execution_id, + }, ) async def update_execution( @@ -816,22 +567,15 @@ async def update_execution( update_data: dict[str, Any], entity_type: str = "components", ) -> dict[str, Any]: - """Update an execution (e.g., stop capability). - - Args: - entity_id: The entity identifier. - operation_name: The operation name. - execution_id: The execution identifier. - update_data: Update data (e.g., {"stop": True}). - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - Updated execution dictionary. - """ - return await self._request( - "PUT", - f"/{entity_type}/{entity_id}/operations/{operation_name}/executions/{execution_id}", - json_body=update_data, + fn = _entity_func("operations", "update_execution", entity_type) + return await self._call( + fn, + **{ + _entity_id_kwarg(entity_type): entity_id, + "operation_id": operation_name, + "execution_id": execution_id, + "body": update_data, + }, ) async def cancel_execution( @@ -841,127 +585,68 @@ async def cancel_execution( execution_id: str, entity_type: str = "components", ) -> dict[str, Any]: - """Cancel a specific execution. - - Args: - entity_id: The entity identifier. - operation_name: The operation name. - execution_id: The execution identifier. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - Cancellation response dictionary. - """ - return await self._request( - "DELETE", - f"/{entity_type}/{entity_id}/operations/{operation_name}/executions/{execution_id}", + fn = _entity_func("operations", "cancel_execution", entity_type) + return await self._call( + fn, + **{ + _entity_id_kwarg(entity_type): entity_id, + "operation_id": operation_name, + "execution_id": execution_id, + }, ) - # ==================== Configurations (ROS 2 Parameters) ==================== + # ==================== Configurations ==================== async def list_configurations( self, entity_id: str, entity_type: str = "components" ) -> list[dict[str, Any]]: - """List all configurations (parameters) for an entity. - - Args: - entity_id: The entity identifier. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - List of configuration dictionaries. - """ - result = await self._request("GET", f"/{entity_type}/{entity_id}/configurations") - if isinstance(result, list): - return result - if isinstance(result, dict): - if "configurations" in result: - return result["configurations"] - if "items" in result: - return result["items"] - return [result] if result else [] + fn = _entity_func("configurations", "list", entity_type) + return _extract_items(await self._call(fn, **{_entity_id_kwarg(entity_type): entity_id})) async def get_configuration( self, entity_id: str, param_name: str, entity_type: str = "components" ) -> dict[str, Any]: - """Get a specific configuration (parameter) value. - - Args: - entity_id: The entity identifier. - param_name: The parameter name. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - Configuration value dictionary. - """ - return await self._request("GET", f"/{entity_type}/{entity_id}/configurations/{param_name}") + fn = _entity_func("configurations", "get", entity_type) + return await self._call( + fn, + **{_entity_id_kwarg(entity_type): entity_id, "config_id": param_name}, + ) async def set_configuration( self, entity_id: str, param_name: str, value: Any, entity_type: str = "components" ) -> dict[str, Any]: - """Set a configuration (parameter) value. - - Args: - entity_id: The entity identifier. - param_name: The parameter name. - value: The new parameter value. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - Response dictionary with set status. - """ - return await self._request( - "PUT", - f"/{entity_type}/{entity_id}/configurations/{param_name}", - json_body={"value": value}, + fn = _entity_func("configurations", "set", entity_type) + return await self._call( + fn, + **{ + _entity_id_kwarg(entity_type): entity_id, + "config_id": param_name, + "body": {"data": value}, + }, ) async def delete_configuration( self, entity_id: str, param_name: str, entity_type: str = "components" ) -> dict[str, Any]: - """Reset a configuration (parameter) to its default value. - - Args: - entity_id: The entity identifier. - param_name: The parameter name. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - Response dictionary. - """ - return await self._request( - "DELETE", f"/{entity_type}/{entity_id}/configurations/{param_name}" + fn = _entity_func("configurations", "delete", entity_type) + return await self._call( + fn, + **{_entity_id_kwarg(entity_type): entity_id, "config_id": param_name}, ) async def delete_all_configurations( self, entity_id: str, entity_type: str = "components" ) -> dict[str, Any]: - """Reset all configurations (parameters) to their default values. - - Args: - entity_id: The entity identifier. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - Response dictionary. - """ - return await self._request("DELETE", f"/{entity_type}/{entity_id}/configurations") + fn = _entity_func("configurations", "delete_all", entity_type) + return await self._call(fn, **{_entity_id_kwarg(entity_type): entity_id}) # ==================== Bulk Data ==================== async def list_bulk_data_categories( self, entity_id: str, entity_type: str = "apps" ) -> list[str]: - """List available bulk-data categories for an entity. - - Args: - entity_id: The entity identifier. - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - List of category names (e.g., ['rosbags', 'logs']). - """ - result = await self._request("GET", f"/{entity_type}/{entity_id}/bulk-data") + fn = _entity_func("bulk_data", "list_categories", entity_type) + result = await self._call(fn, **{_entity_id_kwarg(entity_type): entity_id}) if isinstance(result, dict) and "items" in result: return result["items"] if isinstance(result, list): @@ -971,34 +656,22 @@ async def list_bulk_data_categories( async def list_bulk_data( self, entity_id: str, category: str, entity_type: str = "apps" ) -> list[dict[str, Any]]: - """List bulk-data items in a category. - - Args: - entity_id: The entity identifier. - category: Category name (e.g., 'rosbags'). - entity_type: Entity type ('components', 'apps', 'areas', 'functions'). - - Returns: - List of bulk data item dictionaries. - """ - result = await self._request("GET", f"/{entity_type}/{entity_id}/bulk-data/{category}") - if isinstance(result, dict) and "items" in result: - return result["items"] - if isinstance(result, list): - return result - return [] + fn = _entity_func("bulk_data", "list", entity_type) + return _extract_items( + await self._call( + fn, + **{_entity_id_kwarg(entity_type): entity_id, "category_id": category}, + ) + ) async def get_bulk_data_info(self, bulk_data_uri: str) -> dict[str, Any]: - """Get metadata about a bulk-data item via HEAD request. - - Args: - bulk_data_uri: Full bulk-data URI path. - - Returns: - Dictionary with Content-Type, Content-Length, filename. - """ - client = await self._ensure_client() - response = await client.head(bulk_data_uri) + """Get metadata about a bulk-data item via HEAD request.""" + _validate_relative_uri(bulk_data_uri) + hc = await self._httpx_client() + try: + response = await hc.head(bulk_data_uri) + except httpx.RequestError as e: + raise SovdClientError(message=f"Request failed: {e}") from e if not response.is_success: raise SovdClientError( @@ -1006,41 +679,21 @@ async def get_bulk_data_info(self, bulk_data_uri: str) -> dict[str, Any]: status_code=response.status_code, ) - headers = response.headers - content_disposition = headers.get("Content-Disposition", "") - filename = None - if "filename=" in content_disposition: - import re - - match = re.search(r'filename="?([^"]+)"?', content_disposition) - if match: - filename = match.group(1) - return { - "content_type": headers.get("Content-Type", "application/octet-stream"), - "content_length": headers.get("Content-Length"), - "filename": filename, + "content_type": response.headers.get("Content-Type", "application/octet-stream"), + "content_length": response.headers.get("Content-Length"), + "filename": _extract_filename(response.headers.get("Content-Disposition", "")), "uri": bulk_data_uri, } async def download_bulk_data(self, bulk_data_uri: str) -> tuple[bytes, str | None]: - """Download a bulk-data file. - - Args: - bulk_data_uri: Relative bulk-data URI path (must start with /). - - Returns: - Tuple of (file_content, filename). - - Raises: - ValueError: If the URI is an absolute URL (SSRF protection). - """ - # SSRF protection: reject absolute URLs - only allow relative paths - if bulk_data_uri.startswith(("http://", "https://", "//")): - raise ValueError(f"Absolute URLs not allowed for bulk data download: {bulk_data_uri}") - - client = await self._ensure_client() - response = await client.get(bulk_data_uri, timeout=httpx.Timeout(300.0)) + """Download a bulk-data file.""" + _validate_relative_uri(bulk_data_uri) + hc = await self._httpx_client() + try: + response = await hc.get(bulk_data_uri, timeout=httpx.Timeout(300.0)) + except httpx.RequestError as e: + raise SovdClientError(message=f"Request failed: {e}") from e if not response.is_success: raise SovdClientError( @@ -1048,28 +701,12 @@ async def download_bulk_data(self, bulk_data_uri: str) -> tuple[bytes, str | Non status_code=response.status_code, ) - content_disposition = response.headers.get("Content-Disposition", "") - filename = None - if "filename=" in content_disposition: - import re - - match = re.search(r'filename="?([^"]+)"?', content_disposition) - if match: - filename = match.group(1) - - return response.content, filename + return response.content, _extract_filename(response.headers.get("Content-Disposition", "")) @asynccontextmanager async def create_client(settings: Settings) -> AsyncIterator[SovdClient]: - """Create and manage SOVD client lifecycle. - - Args: - settings: Application settings. - - Yields: - Initialized SOVD client. - """ + """Create and manage SOVD client lifecycle.""" client = SovdClient(settings) try: yield client diff --git a/src/ros2_medkit_mcp/mcp_app.py b/src/ros2_medkit_mcp/mcp_app.py index 1d8d5b6..7538902 100644 --- a/src/ros2_medkit_mcp/mcp_app.py +++ b/src/ros2_medkit_mcp/mcp_app.py @@ -175,8 +175,9 @@ def format_fault_list(faults: list[dict[str, Any]]) -> list[TextContent]: lines.append("") except Exception: # Fallback to basic formatting if model validation fails - code = fault_dict.get("code", "unknown") - name = fault_dict.get("faultName", "") + # Check both camelCase (raw gateway) and snake_case (generated client) + code = fault_dict.get("code") or fault_dict.get("fault_code", "unknown") + name = fault_dict.get("faultName") or fault_dict.get("fault_name", "") severity = fault_dict.get("severity", "") status = fault_dict.get("status", "") lines.append(f"Fault: {code}" + (f" - {name}" if name else "")) @@ -259,8 +260,8 @@ def format_fault_response(fault_data: dict[str, Any]) -> list[TextContent]: item = FaultItem.model_validate(item_data) lines.append(format_fault_item(item)) except Exception: - # Fallback to basic formatting - code = item_data.get("code", "unknown") + # Fallback - check both camelCase (raw gateway) and snake_case (generated client) + code = item_data.get("code") or item_data.get("fault_code", "unknown") lines.append(f"Fault: {code}") # Parse environment data if present @@ -1746,7 +1747,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: args.entity_id, args.operation_name, args.execution_id, - args.request_data, + args.update_data, args.entity_type, ) return format_json_response(result) diff --git a/src/ros2_medkit_mcp/models.py b/src/ros2_medkit_mcp/models.py index 444aac4..4a0d43d 100644 --- a/src/ros2_medkit_mcp/models.py +++ b/src/ros2_medkit_mcp/models.py @@ -434,7 +434,11 @@ class FaultStatus(str, Enum): class FaultItem(BaseModel): """Fault item model per SOVD specification.""" - code: str = Field(..., description="Fault code (DTC)") + code: str = Field( + ..., + alias="fault_code", + description="Fault code (DTC)", + ) fault_name: str | None = Field( default=None, alias="faultName", diff --git a/tests/test_bulkdata_tools.py b/tests/test_bulkdata_tools.py index 2b6d12d..ac4d157 100644 --- a/tests/test_bulkdata_tools.py +++ b/tests/test_bulkdata_tools.py @@ -275,7 +275,6 @@ def test_save_without_filename(self) -> None: ) assert "Downloaded successfully" in result[0].text - # Should use last URI component with .mcap extension assert "my-uuid-123.mcap" in result[0].text def test_save_creates_directory(self) -> None: @@ -312,8 +311,8 @@ async def test_list_bulk_data_categories(self, client: SovdClient) -> None: async def test_list_bulk_data(self, client: SovdClient) -> None: """Test list_bulk_data method.""" items = [ - {"id": "uuid-1", "name": "File 1"}, - {"id": "uuid-2", "name": "File 2"}, + {"id": "uuid-1", "name": "File 1", "size": 1024}, + {"id": "uuid-2", "name": "File 2", "size": 2048}, ] respx.get("http://test-sovd:8080/api/v1/apps/motor/bulk-data/rosbags").mock( return_value=httpx.Response(200, json={"items": items}) @@ -403,20 +402,25 @@ class TestDownloadRosbagsForFault: async def test_download_rosbags_success(self, client: SovdClient) -> None: """Test downloading rosbags for a fault.""" fault_response = { - "item": {"code": "MOTOR_OVERHEAT", "faultName": "Motor Overheating"}, - "environmentData": { - "extendedDataRecords": { - "freezeFrameSnapshots": [], - "rosbagSnapshots": [ + "item": { + "code": "MOTOR_OVERHEAT", + "severity": "high", + "status": {"aggregatedStatus": "active"}, + "fault_name": "Motor Overheating", + }, + "environment_data": { + "extended_data_records": { + "freeze_frame_snapshots": [], + "rosbag_snapshots": [ { - "snapshotId": "rb-1", + "snapshot_id": "rb-1", "timestamp": "2026-02-04T10:00:00Z", - "bulkDataUri": "/apps/motor/bulk-data/rosbags/rb-1", + "bulk_data_uri": "/apps/motor/bulk-data/rosbags/rb-1", }, { - "snapshotId": "rb-2", + "snapshot_id": "rb-2", "timestamp": "2026-02-04T10:01:00Z", - "bulkDataUri": "/apps/motor/bulk-data/rosbags/rb-2", + "bulk_data_uri": "/apps/motor/bulk-data/rosbags/rb-2", }, ], } @@ -461,13 +465,17 @@ async def test_download_rosbags_success(self, client: SovdClient) -> None: async def test_download_only_freeze_frames(self, client: SovdClient) -> None: """Test fault with only freeze frames (no rosbags).""" fault_response = { - "item": {"code": "MINOR_FAULT"}, - "environmentData": { - "extendedDataRecords": { - "freezeFrameSnapshots": [ - {"snapshotId": "ff-1", "timestamp": "2026-02-04T10:00:00Z", "data": {}} + "item": { + "code": "MINOR_FAULT", + "severity": "low", + "status": {"aggregatedStatus": "active"}, + }, + "environment_data": { + "extended_data_records": { + "freeze_frame_snapshots": [ + {"snapshot_id": "ff-1", "timestamp": "2026-02-04T10:00:00Z", "data": {}} ], - "rosbagSnapshots": [], + "rosbag_snapshots": [], } }, } @@ -487,7 +495,14 @@ async def test_download_only_freeze_frames(self, client: SovdClient) -> None: @respx.mock async def test_download_no_environment_data(self, client: SovdClient) -> None: """Test fault without environment data.""" - fault_response = {"item": {"code": "NO_ENV_FAULT"}} + fault_response = { + "item": { + "code": "NO_ENV_FAULT", + "severity": "low", + "status": {"aggregatedStatus": "active"}, + }, + "environment_data": {}, + } respx.get("http://test-sovd:8080/api/v1/apps/motor/faults/NO_ENV_FAULT").mock( return_value=httpx.Response(200, json=fault_response) @@ -503,20 +518,24 @@ async def test_download_no_environment_data(self, client: SovdClient) -> None: async def test_download_with_errors(self, client: SovdClient) -> None: """Test downloading with some failures.""" fault_response = { - "item": {"code": "TEST_FAULT"}, - "environmentData": { - "extendedDataRecords": { - "freezeFrameSnapshots": [], - "rosbagSnapshots": [ + "item": { + "code": "TEST_FAULT", + "severity": "high", + "status": {"aggregatedStatus": "active"}, + }, + "environment_data": { + "extended_data_records": { + "freeze_frame_snapshots": [], + "rosbag_snapshots": [ { - "snapshotId": "rb-ok", + "snapshot_id": "rb-ok", "timestamp": "2026-02-04T10:00:00Z", - "bulkDataUri": "/apps/motor/bulk-data/rosbags/rb-ok", + "bulk_data_uri": "/apps/motor/bulk-data/rosbags/rb-ok", }, { - "snapshotId": "rb-fail", + "snapshot_id": "rb-fail", "timestamp": "2026-02-04T10:01:00Z", - "bulkDataUri": "/apps/motor/bulk-data/rosbags/rb-fail", + "bulk_data_uri": "/apps/motor/bulk-data/rosbags/rb-fail", }, ], } diff --git a/tests/test_mcp_app.py b/tests/test_mcp_app.py index a91424f..68093ae 100644 --- a/tests/test_mcp_app.py +++ b/tests/test_mcp_app.py @@ -85,9 +85,20 @@ class TestCallToolIntegration: @respx.mock async def test_version_call(self, client: SovdClient) -> None: """Test version tool integration.""" - expected = {"version": "1.0.0", "name": "ros2_medkit"} respx.get("http://test-sovd:8080/api/v1/version-info").mock( - return_value=httpx.Response(200, json=expected) + return_value=httpx.Response( + 200, + json={ + "items": [ + { + "base_uri": "/api/v1", + "version": "1.0.0", + "api_name": "ros2_medkit", + "api_version": "1.0.0", + } + ] + }, + ) ) result = await client.get_version() @@ -99,23 +110,25 @@ async def test_version_call(self, client: SovdClient) -> None: @respx.mock async def test_entities_list_call(self, client: SovdClient) -> None: """Test entities_list tool integration.""" - areas = [{"id": "powertrain", "type": "Area"}] - components = [{"id": "temp_sensor", "type": "Component"}] respx.get("http://test-sovd:8080/api/v1/areas").mock( - return_value=httpx.Response(200, json=areas) + return_value=httpx.Response( + 200, json={"items": [{"id": "powertrain", "name": "powertrain", "type": "Area"}]} + ) ) respx.get("http://test-sovd:8080/api/v1/components").mock( - return_value=httpx.Response(200, json=components) + return_value=httpx.Response( + 200, + json={"items": [{"id": "temp_sensor", "name": "temp_sensor", "type": "Component"}]}, + ) ) respx.get("http://test-sovd:8080/api/v1/apps").mock( - return_value=httpx.Response(200, json=[]) + return_value=httpx.Response(200, json={"items": []}) ) respx.get("http://test-sovd:8080/api/v1/functions").mock( - return_value=httpx.Response(200, json=[]) + return_value=httpx.Response(200, json={"items": []}) ) entities = await client.list_entities() - # Apply filter like the tool does args = EntitiesListArgs(filter=None) filtered = filter_entities(entities, args.filter) formatted = format_json_response(filtered) @@ -127,22 +140,27 @@ async def test_entities_list_call(self, client: SovdClient) -> None: @respx.mock async def test_entities_list_with_filter(self, client: SovdClient) -> None: """Test entities_list tool with filter.""" - areas = [{"id": "powertrain", "type": "Area"}] - components = [ - {"id": "temp_sensor", "type": "Component"}, - {"id": "rpm_sensor", "type": "Component"}, - ] respx.get("http://test-sovd:8080/api/v1/areas").mock( - return_value=httpx.Response(200, json=areas) + return_value=httpx.Response( + 200, json={"items": [{"id": "powertrain", "name": "powertrain", "type": "Area"}]} + ) ) respx.get("http://test-sovd:8080/api/v1/components").mock( - return_value=httpx.Response(200, json=components) + return_value=httpx.Response( + 200, + json={ + "items": [ + {"id": "temp_sensor", "name": "temp_sensor", "type": "Component"}, + {"id": "rpm_sensor", "name": "rpm_sensor", "type": "Component"}, + ] + }, + ) ) respx.get("http://test-sovd:8080/api/v1/apps").mock( - return_value=httpx.Response(200, json=[]) + return_value=httpx.Response(200, json={"items": []}) ) respx.get("http://test-sovd:8080/api/v1/functions").mock( - return_value=httpx.Response(200, json=[]) + return_value=httpx.Response(200, json={"items": []}) ) entities = await client.list_entities() @@ -157,9 +175,11 @@ async def test_entities_list_with_filter(self, client: SovdClient) -> None: @respx.mock async def test_faults_list_call(self, client: SovdClient) -> None: """Test faults_list tool integration.""" - faults = [{"id": "fault-1", "severity": "high"}] respx.get("http://test-sovd:8080/api/v1/components/test-comp/faults").mock( - return_value=httpx.Response(200, json=faults) + return_value=httpx.Response( + 200, + json={"items": [{"fault_code": "fault-1", "severity": "high", "status": "active"}]}, + ) ) args = FaultsListArgs(entity_id="test-comp", entity_type="components") @@ -172,9 +192,10 @@ async def test_faults_list_call(self, client: SovdClient) -> None: @respx.mock async def test_list_operations_call(self, client: SovdClient) -> None: """Test list_operations tool integration.""" - operations = [{"name": "test_service", "type": "service"}] respx.get("http://test-sovd:8080/api/v1/components/test-comp/operations").mock( - return_value=httpx.Response(200, json=operations) + return_value=httpx.Response( + 200, json={"items": [{"id": "test_service", "name": "test_service"}]} + ) ) args = ListOperationsArgs(entity_id="test-comp", entity_type="components") @@ -188,14 +209,18 @@ async def test_list_operations_call(self, client: SovdClient) -> None: async def test_client_error_formatting(self, client: SovdClient) -> None: """Test client error is properly formatted.""" respx.get("http://test-sovd:8080/api/v1/version-info").mock( - return_value=httpx.Response(500, text="Internal Server Error") + return_value=httpx.Response( + 500, + json={ + "error_code": "internal-error", + "message": "Internal Server Error", + }, + ) ) - with pytest.raises(SovdClientError) as exc_info: + with pytest.raises(SovdClientError): await client.get_version() - error_formatted = format_error(str(exc_info.value)) - assert "500" in error_formatted[0].text await client.close() diff --git a/tests/test_tools.py b/tests/test_tools.py index ad4a351..e20d3c4 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -48,67 +48,97 @@ class TestSovdClient: @respx.mock async def test_get_version_success(self, client: SovdClient) -> None: """Test successful version retrieval.""" - expected = {"version": "1.0.0", "name": "ros2_medkit"} respx.get("http://test-sovd:8080/api/v1/version-info").mock( - return_value=httpx.Response(200, json=expected) + return_value=httpx.Response( + 200, + json={ + "items": [ + { + "base_uri": "/api/v1", + "version": "1.0.0", + "api_name": "ros2_medkit", + "api_version": "1.0.0", + } + ] + }, + ) ) result = await client.get_version() - assert result == expected + assert result["items"][0]["version"] == "1.0.0" + assert result["items"][0]["api_name"] == "ros2_medkit" await client.close() @respx.mock async def test_get_version_with_request_id(self, client: SovdClient) -> None: - """Test version retrieval logs request ID from response.""" - expected = {"version": "1.0.0"} + """Test version retrieval with request ID in response headers.""" respx.get("http://test-sovd:8080/api/v1/version-info").mock( return_value=httpx.Response( 200, - json=expected, + json={ + "items": [ + { + "base_uri": "/api/v1", + "version": "1.0.0", + "api_name": "test", + "api_version": "1.0.0", + } + ] + }, headers={"X-Request-ID": "req-123"}, ) ) result = await client.get_version() - assert result == expected + assert result["items"][0]["version"] == "1.0.0" await client.close() @respx.mock async def test_get_version_error(self, client: SovdClient) -> None: """Test version retrieval with error response.""" respx.get("http://test-sovd:8080/api/v1/version-info").mock( - return_value=httpx.Response(500, text="Internal Server Error") + return_value=httpx.Response( + 500, + json={ + "error_code": "internal-error", + "message": "Internal Server Error", + }, + ) ) - with pytest.raises(SovdClientError) as exc_info: + with pytest.raises(SovdClientError): await client.get_version() - assert exc_info.value.status_code == 500 await client.close() @respx.mock async def test_list_entities_success(self, client: SovdClient) -> None: """Test successful entities listing.""" - areas = [{"id": "powertrain", "type": "Area"}] - components = [ - {"id": "temp_sensor", "name": "Temperature Sensor", "type": "Component"}, - {"id": "rpm_sensor", "name": "RPM Sensor", "type": "Component"}, - ] - apps = [{"id": "node_1", "type": "App"}] - functions: list[dict] = [] respx.get("http://test-sovd:8080/api/v1/areas").mock( - return_value=httpx.Response(200, json=areas) + return_value=httpx.Response( + 200, json={"items": [{"id": "powertrain", "name": "powertrain", "type": "Area"}]} + ) ) respx.get("http://test-sovd:8080/api/v1/components").mock( - return_value=httpx.Response(200, json=components) + return_value=httpx.Response( + 200, + json={ + "items": [ + {"id": "temp_sensor", "name": "Temperature Sensor", "type": "Component"}, + {"id": "rpm_sensor", "name": "RPM Sensor", "type": "Component"}, + ] + }, + ) ) respx.get("http://test-sovd:8080/api/v1/apps").mock( - return_value=httpx.Response(200, json=apps) + return_value=httpx.Response( + 200, json={"items": [{"id": "node_1", "name": "node_1", "type": "App"}]} + ) ) respx.get("http://test-sovd:8080/api/v1/functions").mock( - return_value=httpx.Response(200, json=functions) + return_value=httpx.Response(200, json={"items": []}) ) result = await client.list_entities() @@ -118,21 +148,28 @@ async def test_list_entities_success(self, client: SovdClient) -> None: @respx.mock async def test_list_entities_wrapped_response(self, client: SovdClient) -> None: - """Test entities listing with wrapped response.""" - areas = [{"id": "powertrain", "type": "Area"}] - components = [{"id": "temp_sensor", "type": "Component"}] + """Test entities listing when some endpoints return errors.""" respx.get("http://test-sovd:8080/api/v1/areas").mock( - return_value=httpx.Response(200, json={"areas": areas}) + return_value=httpx.Response( + 200, json={"items": [{"id": "powertrain", "name": "powertrain", "type": "Area"}]} + ) ) respx.get("http://test-sovd:8080/api/v1/components").mock( - return_value=httpx.Response(200, json={"components": components}) + return_value=httpx.Response( + 200, + json={"items": [{"id": "temp_sensor", "name": "temp_sensor", "type": "Component"}]}, + ) ) - # Apps and functions may return 404 in some deployments - we catch the exception + # Apps and functions return 404 - should be caught respx.get("http://test-sovd:8080/api/v1/apps").mock( - return_value=httpx.Response(404, json={"error": "Not Found"}) + return_value=httpx.Response( + 404, json={"error_code": "not-found", "message": "Not Found"} + ) ) respx.get("http://test-sovd:8080/api/v1/functions").mock( - return_value=httpx.Response(404, json={"error": "Not Found"}) + return_value=httpx.Response( + 404, json={"error_code": "not-found", "message": "Not Found"} + ) ) result = await client.list_entities() @@ -143,30 +180,33 @@ async def test_list_entities_wrapped_response(self, client: SovdClient) -> None: @respx.mock async def test_get_entity_success(self, client: SovdClient) -> None: """Test successful entity retrieval.""" - areas: list[dict] = [] - components = [ - { - "id": "temp_sensor", - "name": "Temperature Sensor", - "type": "Component", - "fqn": "/powertrain/temp_sensor", - }, - ] - component_data = [{"topic": "/temperature", "data": {"value": 85.5}}] respx.get("http://test-sovd:8080/api/v1/areas").mock( - return_value=httpx.Response(200, json=areas) + return_value=httpx.Response(200, json={"items": []}) ) respx.get("http://test-sovd:8080/api/v1/components").mock( - return_value=httpx.Response(200, json=components) + return_value=httpx.Response( + 200, + json={ + "items": [ + { + "id": "temp_sensor", + "name": "Temperature Sensor", + "type": "Component", + } + ] + }, + ) ) respx.get("http://test-sovd:8080/api/v1/apps").mock( - return_value=httpx.Response(200, json=[]) + return_value=httpx.Response(200, json={"items": []}) ) respx.get("http://test-sovd:8080/api/v1/functions").mock( - return_value=httpx.Response(200, json=[]) + return_value=httpx.Response(200, json={"items": []}) ) respx.get("http://test-sovd:8080/api/v1/components/temp_sensor/data").mock( - return_value=httpx.Response(200, json=component_data) + return_value=httpx.Response( + 200, json={"items": [{"id": "temperature", "name": "temperature"}]} + ) ) result = await client.get_entity("temp_sensor") @@ -177,18 +217,18 @@ async def test_get_entity_success(self, client: SovdClient) -> None: @respx.mock async def test_get_entity_not_found(self, client: SovdClient) -> None: - """Test entity retrieval with 404 response.""" + """Test entity retrieval when entity does not exist.""" respx.get("http://test-sovd:8080/api/v1/areas").mock( - return_value=httpx.Response(200, json=[]) + return_value=httpx.Response(200, json={"items": []}) ) respx.get("http://test-sovd:8080/api/v1/components").mock( - return_value=httpx.Response(200, json=[]) + return_value=httpx.Response(200, json={"items": []}) ) respx.get("http://test-sovd:8080/api/v1/apps").mock( - return_value=httpx.Response(200, json=[]) + return_value=httpx.Response(200, json={"items": []}) ) respx.get("http://test-sovd:8080/api/v1/functions").mock( - return_value=httpx.Response(200, json=[]) + return_value=httpx.Response(200, json={"items": []}) ) with pytest.raises(SovdClientError) as exc_info: @@ -200,50 +240,67 @@ async def test_get_entity_not_found(self, client: SovdClient) -> None: @respx.mock async def test_list_faults_success(self, client: SovdClient) -> None: """Test successful faults listing.""" - expected = [ - {"id": "fault-1", "severity": "high", "entity_id": "entity-1"}, - {"id": "fault-2", "severity": "low", "entity_id": "entity-2"}, + fault_items = [ + {"fault_code": "fault-1", "severity": "high", "status": "active"}, + {"fault_code": "fault-2", "severity": "low", "status": "active"}, ] respx.get("http://test-sovd:8080/api/v1/components/test-component/faults").mock( - return_value=httpx.Response(200, json=expected) + return_value=httpx.Response(200, json={"items": fault_items}) ) result = await client.list_faults("test-component") - assert result == expected + assert len(result) == 2 + assert result[0]["fault_code"] == "fault-1" await client.close() @respx.mock async def test_list_faults_different_component(self, client: SovdClient) -> None: """Test faults listing for different component.""" - expected = [{"id": "fault-1", "severity": "high"}] respx.get("http://test-sovd:8080/api/v1/components/other-component/faults").mock( - return_value=httpx.Response(200, json=expected) + return_value=httpx.Response( + 200, + json={"items": [{"fault_code": "fault-1", "severity": "high", "status": "active"}]}, + ) ) result = await client.list_faults("other-component") - assert result == expected + assert len(result) == 1 + assert result[0]["fault_code"] == "fault-1" await client.close() @respx.mock async def test_list_faults_wrapped_response(self, client: SovdClient) -> None: - """Test faults listing with wrapped response.""" - faults = [{"id": "fault-1", "severity": "high"}] + """Test faults listing with items wrapper.""" + faults = [{"fault_code": "fault-1", "severity": "high", "status": "active"}] respx.get("http://test-sovd:8080/api/v1/components/test-component/faults").mock( - return_value=httpx.Response(200, json={"faults": faults}) + return_value=httpx.Response(200, json={"items": faults}) ) result = await client.list_faults("test-component") - assert result == faults + assert len(result) == 1 + assert result[0]["fault_code"] == "fault-1" await client.close() @respx.mock async def test_authentication_header(self, client_with_auth: SovdClient) -> None: """Test that authentication header is sent when configured.""" route = respx.get("http://test-sovd:8080/api/v1/version-info").mock( - return_value=httpx.Response(200, json={"version": "1.0.0"}) + return_value=httpx.Response( + 200, + json={ + "items": [ + { + "base_uri": "/api/v1", + "version": "1.0.0", + "api_name": "test", + "api_version": "1.0.0", + } + ] + }, + ) ) await client_with_auth.get_version() @@ -259,10 +316,9 @@ async def test_timeout_handling(self, client: SovdClient) -> None: side_effect=httpx.ReadTimeout("Connection timed out") ) - with pytest.raises(SovdClientError) as exc_info: + with pytest.raises(SovdClientError, match="timed out"): await client.get_version() - assert "timed out" in str(exc_info.value).lower() await client.close() @respx.mock @@ -274,10 +330,9 @@ async def test_non_json_response(self, client: SovdClient) -> None: ) ) - with pytest.raises(SovdClientError) as exc_info: + with pytest.raises((SovdClientError, Exception)): await client.get_version() - assert "invalid json" in str(exc_info.value).lower() await client.close() @respx.mock @@ -287,10 +342,9 @@ async def test_connection_error_handling(self, client: SovdClient) -> None: side_effect=httpx.ConnectError("Connection refused") ) - with pytest.raises(SovdClientError) as exc_info: + with pytest.raises(SovdClientError, match="failed"): await client.get_version() - assert "refused" in str(exc_info.value).lower() await client.close() From a8e7fff8ab48289131dbff3ccf22c93986c79356 Mon Sep 17 00:00:00 2001 From: Bartosz Burda Date: Mon, 30 Mar 2026 17:47:31 +0200 Subject: [PATCH 2/4] fix: update poetry.lock hash for rebuilt wheel --- poetry.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 1379021..93fcd1f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1250,7 +1250,7 @@ optional = false python-versions = ">=3.11" groups = ["main"] files = [ - {file = "ros2_medkit_client-0.1.0-py3-none-any.whl", hash = "sha256:49c60c0070f90d6272ebd16236a5fb0acce470e4cdf5f5fb811ac14f6d483c14"}, + {file = "ros2_medkit_client-0.1.0-py3-none-any.whl", hash = "sha256:457d7738577d8b5639056e4578aab4d4696f17e0eb50b88ea6866706a02cc934"}, ] [package.dependencies] From 1e2311a11250dd398b2c2e3584ebb6092089c3aa Mon Sep 17 00:00:00 2001 From: Bartosz Burda Date: Mon, 30 Mar 2026 18:20:40 +0200 Subject: [PATCH 3/4] fix: address Copilot review feedback - _validate_relative_uri raises SovdClientError instead of ValueError - _raw_request handles JSON decode failures - _call catches ValueError/KeyError from generated client parsing - Narrowed test_non_json_response to assert SovdClientError only - test_get_version_error asserts error code in message - Added comment about URL dependency in pyproject.toml --- pyproject.toml | 1 + src/ros2_medkit_mcp/client.py | 12 ++++++++++-- tests/test_tools.py | 5 +++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 312d18d..0654dce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ httpx = "^0.28.0" pydantic = "^2.10.0" uvicorn = { version = "^0.34.0", extras = ["standard"] } starlette = "^0.45.0" +# Distributed via GitHub Releases wheel (no PyPI yet). Replace with version constraint when available. ros2-medkit-client = {url = "https://github.com/selfpatch/ros2_medkit_clients/releases/download/py-v0.1.0/ros2_medkit_client-0.1.0-py3-none-any.whl"} [tool.poetry.group.dev.dependencies] diff --git a/src/ros2_medkit_mcp/client.py b/src/ros2_medkit_mcp/client.py index 9ce55b1..505abe0 100644 --- a/src/ros2_medkit_mcp/client.py +++ b/src/ros2_medkit_mcp/client.py @@ -91,7 +91,7 @@ def _extract_filename(content_disposition: str) -> str | None: def _validate_relative_uri(uri: str) -> None: """Reject absolute URLs to prevent SSRF.""" if uri.startswith(("http://", "https://", "//")): - raise ValueError(f"Absolute URLs not allowed: {uri}") + raise SovdClientError(f"Absolute URLs not allowed: {uri}") # Mapping of (entity_type, resource, method) -> generated function name @@ -310,6 +310,8 @@ async def _call(self, api_func: Any, **kwargs: Any) -> Any: raise SovdClientError(message=f"Request timed out: {e}") from e except httpx.RequestError as e: raise SovdClientError(message=f"Request failed: {e}") from e + except (ValueError, KeyError) as e: + raise SovdClientError(message=f"Failed to parse response: {e}") from e async def _raw_request(self, method: str, path: str) -> Any: """Make a raw HTTP request for endpoints not in the generated client @@ -322,7 +324,13 @@ async def _raw_request(self, method: str, path: str) -> Any: message=f"Gateway returned HTTP {response.status_code}", status_code=response.status_code, ) - return response.json() + try: + return response.json() + except ValueError as e: + raise SovdClientError( + message="Failed to decode JSON response from gateway", + status_code=response.status_code, + ) from e except httpx.RequestError as e: raise SovdClientError(message=f"Request failed: {e}") from e diff --git a/tests/test_tools.py b/tests/test_tools.py index e20d3c4..931aef8 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -108,9 +108,10 @@ async def test_get_version_error(self, client: SovdClient) -> None: ) ) - with pytest.raises(SovdClientError): + with pytest.raises(SovdClientError) as exc_info: await client.get_version() + assert "internal-error" in str(exc_info.value) await client.close() @respx.mock @@ -330,7 +331,7 @@ async def test_non_json_response(self, client: SovdClient) -> None: ) ) - with pytest.raises((SovdClientError, Exception)): + with pytest.raises(SovdClientError): await client.get_version() await client.close() From f436720442238fa5b60faa2610e9a85a64394992 Mon Sep 17 00:00:00 2001 From: Bartosz Burda Date: Mon, 30 Mar 2026 21:24:15 +0200 Subject: [PATCH 4/4] fix: address mfaferek93 review - body model wrapping, import validation - Auto-wrap body dicts as generated model instances via _wrap_body_dict() (fixes runtime crash in set_configuration and all body-accepting endpoints) - Add import-time validation of _ENTITY_FUNC_MAP function references (fails fast on generator rename instead of silent runtime AttributeError) --- src/ros2_medkit_mcp/client.py | 36 ++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/src/ros2_medkit_mcp/client.py b/src/ros2_medkit_mcp/client.py index 505abe0..d3c1b71 100644 --- a/src/ros2_medkit_mcp/client.py +++ b/src/ros2_medkit_mcp/client.py @@ -88,6 +88,23 @@ def _extract_filename(content_disposition: str) -> str | None: return match.group(1) if match else None +def _wrap_body_dict(api_func: Any, body_dict: dict[str, Any]) -> Any: + """Wrap a raw dict as the body model expected by a generated API function. + + Generated functions require attrs model instances (with to_dict()), not raw dicts. + Extracts the body type from the function's signature and creates it via from_dict(). + """ + import inspect + + sig = inspect.signature(api_func) + body_param = sig.parameters.get("body") + if body_param is not None and body_param.annotation is not inspect.Parameter.empty: + body_cls = body_param.annotation + if hasattr(body_cls, "from_dict"): + return body_cls.from_dict(body_dict) + return body_dict + + def _validate_relative_uri(uri: str) -> None: """Reject absolute URLs to prevent SSRF.""" if uri.startswith(("http://", "https://", "//")): @@ -236,6 +253,17 @@ def _validate_relative_uri(uri: str) -> None: } +# Validate all function references at import time +for _resource, _methods in _ENTITY_FUNC_MAP.items(): + for _method, _types in _methods.items(): + for _etype, _func in _types.items(): + if not hasattr(_func, "asyncio"): + raise ImportError( + f"Generated API function {_resource}/{_method}/{_etype} " + f"({_func}) missing .asyncio attribute" + ) + + def _entity_func(resource: str, method: str, entity_type: str) -> Any: """Look up the generated API function for a resource/method/entity_type combo.""" resource_map = _ENTITY_FUNC_MAP.get(resource) @@ -298,7 +326,13 @@ async def close(self) -> None: self._entered = False async def _call(self, api_func: Any, **kwargs: Any) -> Any: - """Call a generated API function, converting errors to SovdClientError.""" + """Call a generated API function, converting errors to SovdClientError. + + Body dicts are auto-wrapped into the generated model expected by the + API function (generated functions require attrs models with to_dict()). + """ + if "body" in kwargs and isinstance(kwargs["body"], dict): + kwargs["body"] = _wrap_body_dict(api_func, kwargs["body"]) client = await self._ensure_client() try: result = await client.call(api_func, **kwargs)