# Patch summary (free-text preamble; patch tools ignore text before the first "diff --git" header).
#
# Purpose: record "error" metrics when a routed call is cancelled, then re-raise the cancellation.
# Why a separate handler: on Python 3.8+ asyncio.CancelledError derives from BaseException, so the
# existing `except Exception` handlers shown as context in these hunks do not catch it.
#
# src/llama_stack/core/routers/tool_runtime.py
#   - adds `import asyncio`.
#   - invoke_tool: new `except asyncio.CancelledError` placed before `except Exception`. It computes the
#     elapsed time, builds attributes with status="error" (from metric_attrs when that is truthy, otherwise
#     via create_tool_metric_attributes), updates tool_invocations_total and tool_duration, and re-raises.
#
# src/llama_stack/core/routers/vector_io.py
#   - insert_chunks, query_chunks, openai_search_vector_store, openai_attach_file_to_vector_store:
#     new handler increments the relevant counter(s) with {**metric_attrs, "status": "error"} and records
#     the duration histogram with metric_attrs (no status key), then re-raises.
#   - openai_delete_vector_store, openai_delete_vector_store_file: new handler only increments
#     vector_deletes_total with status="error" and re-raises; no duration is recorded.
#
# tests/unit/telemetry/test_tool_runtime_metrics.py
#   - test_tool_runtime_metrics_cancelled_error: the provider mock raises asyncio.CancelledError; the test
#     asserts the exception propagates and that invoke_tool was called once. It makes no assertion on the
#     recorded metrics.
#
# tests/unit/telemetry/test_vector_io_metrics.py
#   - test_insert_chunks_records_cancelled_error_metrics: patches vector_inserts_total.add and
#     vector_insert_duration.record, then asserts the counter was called once with status == "error".
#
# NOTE(review): no hunk adds `import asyncio` to vector_io.py, yet six new handlers reference
#   asyncio.CancelledError. Presumably the module already imports asyncio - confirm, otherwise evaluating
#   the except clause would raise NameError.
# NOTE(review): the tool-runtime test does not patch or inspect tool_invocations_total / tool_duration, so
#   it would still pass if the new handler recorded nothing - consider asserting on the metric calls.
# NOTE(review): where the neighbouring `except Exception` body is visible (the two delete paths) the new
#   handler is an exact copy; a shared helper or a combined except clause may be worth considering.
diff --git a/src/llama_stack/core/routers/tool_runtime.py b/src/llama_stack/core/routers/tool_runtime.py index 9769e963b5..f149f5a45c 100644 --- a/src/llama_stack/core/routers/tool_runtime.py +++ b/src/llama_stack/core/routers/tool_runtime.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio import time from typing import Any @@ -83,6 +84,19 @@ async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any], authorizatio return result + except asyncio.CancelledError: + # Record cancellation metrics + duration = time.perf_counter() - start_time + if metric_attrs: + error_attrs = {**metric_attrs, "status": "error"} + else: + error_attrs = create_tool_metric_attributes( + tool_name=tool_name, + status="error", + ) + tool_invocations_total.add(1, error_attrs) + tool_duration.record(duration, error_attrs) + raise except Exception: # Record error metrics duration = time.perf_counter() - start_time diff --git a/src/llama_stack/core/routers/vector_io.py b/src/llama_stack/core/routers/vector_io.py index 5737cd2c2b..4388f349b4 100644 --- a/src/llama_stack/core/routers/vector_io.py +++ b/src/llama_stack/core/routers/vector_io.py @@ -185,6 +185,12 @@ async def insert_chunks( vector_insert_duration.record(duration, metric_attrs) vector_chunks_processed_total.add(num_chunks, metric_attrs) return result + except asyncio.CancelledError: + duration = time.perf_counter() - start_time + error_attrs = {**metric_attrs, "status": "error"} + vector_inserts_total.add(1, error_attrs) + vector_insert_duration.record(duration, metric_attrs) + raise except Exception: duration = time.perf_counter() - start_time error_attrs = {**metric_attrs, "status": "error"} @@ -232,6 +238,12 @@ async def query_chunks( vector_queries_total.add(1, success_attrs) vector_retrieval_duration.record(duration, metric_attrs) return result + except asyncio.CancelledError: + duration = time.perf_counter() - 
start_time + error_attrs = {**metric_attrs, "status": "error"} + vector_queries_total.add(1, error_attrs) + vector_retrieval_duration.record(duration, metric_attrs) + raise except Exception: duration = time.perf_counter() - start_time error_attrs = {**metric_attrs, "status": "error"} @@ -441,6 +453,9 @@ async def openai_delete_vector_store( result = await self.routing_table.openai_delete_vector_store(vector_store_id) vector_deletes_total.add(1, {**metric_attrs, "status": "success"}) return result + except asyncio.CancelledError: + vector_deletes_total.add(1, {**metric_attrs, "status": "error"}) + raise except Exception: vector_deletes_total.add(1, {**metric_attrs, "status": "error"}) raise @@ -485,6 +500,12 @@ async def openai_search_vector_store( vector_queries_total.add(1, success_attrs) vector_retrieval_duration.record(duration, metric_attrs) return result + except asyncio.CancelledError: + duration = time.perf_counter() - start_time + error_attrs = {**metric_attrs, "status": "error"} + vector_queries_total.add(1, error_attrs) + vector_retrieval_duration.record(duration, metric_attrs) + raise except Exception: duration = time.perf_counter() - start_time error_attrs = {**metric_attrs, "status": "error"} @@ -528,6 +549,13 @@ async def openai_attach_file_to_vector_store( vector_inserts_total.add(1, success_attrs) vector_insert_duration.record(duration, metric_attrs) return result + except asyncio.CancelledError: + duration = time.perf_counter() - start_time + error_attrs = {**metric_attrs, "status": "error"} + vector_files_total.add(1, error_attrs) + vector_inserts_total.add(1, error_attrs) + vector_insert_duration.record(duration, metric_attrs) + raise except Exception: duration = time.perf_counter() - start_time error_attrs = {**metric_attrs, "status": "error"} @@ -617,6 +645,9 @@ async def openai_delete_vector_store_file( ) vector_deletes_total.add(1, {**metric_attrs, "status": "success"}) return result + except asyncio.CancelledError: + 
vector_deletes_total.add(1, {**metric_attrs, "status": "error"}) + raise except Exception: vector_deletes_total.add(1, {**metric_attrs, "status": "error"}) raise diff --git a/tests/unit/telemetry/test_tool_runtime_metrics.py b/tests/unit/telemetry/test_tool_runtime_metrics.py index 6ffd4fb913..37b24cbfeb 100644 --- a/tests/unit/telemetry/test_tool_runtime_metrics.py +++ b/tests/unit/telemetry/test_tool_runtime_metrics.py @@ -6,6 +6,7 @@ """Unit tests for tool runtime metrics.""" +import asyncio from unittest.mock import AsyncMock, MagicMock import pytest @@ -190,3 +191,25 @@ async def test_tool_runtime_metrics_error(self): mock_provider.invoke_tool.assert_called_once() # Note: Error metrics (status="error") would be recorded and exported + + async def test_tool_runtime_metrics_cancelled_error(self): + """Test that cancelled tool invocations record error metrics correctly.""" + mock_routing_table = MagicMock() + + mock_provider = AsyncMock() + mock_provider.__provider_id__ = "brave-search::impl" + mock_provider.invoke_tool.side_effect = asyncio.CancelledError() + + mock_routing_table.get_provider_impl = AsyncMock(return_value=mock_provider) + mock_routing_table.tool_to_toolgroup = {"web_search": "websearch"} + + router = ToolRuntimeRouter(routing_table=mock_routing_table) + + with pytest.raises(asyncio.CancelledError): + await router.invoke_tool( + tool_name="web_search", + kwargs={"query": "test query"}, + authorization=None, + ) + + mock_provider.invoke_tool.assert_called_once() diff --git a/tests/unit/telemetry/test_vector_io_metrics.py b/tests/unit/telemetry/test_vector_io_metrics.py index 78f3ce5030..91545e2ae8 100644 --- a/tests/unit/telemetry/test_vector_io_metrics.py +++ b/tests/unit/telemetry/test_vector_io_metrics.py @@ -6,6 +6,7 @@ """Unit tests for vector IO metrics.""" +import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -190,6 +191,26 @@ async def test_insert_chunks_records_error_metrics(self): attrs = 
mock_counter.call_args[0][1] assert attrs["status"] == "error" + async def test_insert_chunks_records_cancelled_error_metrics(self): + router, mock_rt = self._create_mock_router() + mock_rt.insert_chunks = AsyncMock(side_effect=asyncio.CancelledError()) + + mock_request = MagicMock() + mock_request.vector_store_id = "vs_test" + mock_request.chunks = [MagicMock(document_id="doc_1")] + mock_request.ttl_seconds = None + + with ( + patch.object(vector_inserts_total, "add") as mock_counter, + patch.object(vector_insert_duration, "record"), + ): + with pytest.raises(asyncio.CancelledError): + await router.insert_chunks(mock_request) + + mock_counter.assert_called_once() + attrs = mock_counter.call_args[0][1] + assert attrs["status"] == "error" + async def test_query_chunks_records_metrics(self): router, mock_rt = self._create_mock_router() mock_result = MagicMock()