From 552ae02d9f312ab94df1a27c50aaf692f7fa2dcd Mon Sep 17 00:00:00 2001 From: "zhiheng.liu" Date: Sat, 21 Mar 2026 20:23:21 +0800 Subject: [PATCH] fix(server): handle CancelledError during shutdown paths --- openviking/server/app.py | 10 ++++- openviking/server/routers/sessions.py | 3 ++ openviking/session/memory_deduplicator.py | 15 +++++++ tests/server/test_server_health.py | 16 +++++++ tests/session/test_memory_dedup_actions.py | 49 ++++++++++++++++++++++ tests/test_session_task_tracking.py | 21 ++++++++++ tests/transaction/test_lock_manager.py | 13 ++++++ 7 files changed, 125 insertions(+), 2 deletions(-) diff --git a/openviking/server/app.py b/openviking/server/app.py index b7a68cf3a..f2b379b96 100644 --- a/openviking/server/app.py +++ b/openviking/server/app.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """FastAPI application for OpenViking HTTP Server.""" +import asyncio import time from contextlib import asynccontextmanager from typing import Callable, Optional @@ -110,8 +111,13 @@ async def lifespan(app: FastAPI): set_prometheus_observer(None) task_tracker.stop_cleanup_loop() if owns_service and service: - await service.close() - logger.info("OpenVikingService closed") + try: + await service.close() + logger.info("OpenVikingService closed") + except asyncio.CancelledError as e: + logger.warning(f"OpenVikingService close cancelled during shutdown: {e}") + except Exception as e: + logger.warning(f"OpenVikingService close failed during shutdown: {e}") app = FastAPI( title="OpenViking API", diff --git a/openviking/server/routers/sessions.py b/openviking/server/routers/sessions.py index bffcf9348..c196a7a7d 100644 --- a/openviking/server/routers/sessions.py +++ b/openviking/server/routers/sessions.py @@ -250,6 +250,9 @@ async def _background_commit_tracked( task_id, result.get("memories_extracted", 0), ) + except asyncio.CancelledError: + tracker.fail(task_id, "Background commit cancelled during shutdown") + logger.warning("Background commit cancelled: session=%s task=%s", session_id, task_id) except Exception as exc: tracker.fail(task_id, str(exc)) logger.exception("Background commit failed: session=%s task=%s", session_id, task_id) diff --git a/openviking/session/memory_deduplicator.py b/openviking/session/memory_deduplicator.py index f88775d8e..f0ccd74ee 100644 --- a/openviking/session/memory_deduplicator.py +++ b/openviking/session/memory_deduplicator.py @@ -7,6 +7,7 @@ per-existing merge/delete actions. """ +import asyncio import copy import re from dataclasses import dataclass @@ -90,6 +91,10 @@ def __init__( config = get_openviking_config() self.embedder = config.embedding.get_embedder() + def _is_shutdown_in_progress(self) -> bool: + """Whether dedup is running during storage shutdown.""" + return bool(getattr(self.vikingdb, "is_closing", False)) + async def deduplicate( self, candidate: CandidateMemory, @@ -221,6 +226,11 @@ async def _find_similar_memories( return similar, query_vector + except asyncio.CancelledError as e: + if not self._is_shutdown_in_progress(): + raise + logger.warning(f"Vector search cancelled during dedup prefilter: {e}") + return [], query_vector except Exception as e: logger.warning(f"Vector search failed: {e}") return [], query_vector @@ -289,6 +299,11 @@ async def _llm_decision( logger.debug("Dedup LLM parsed payload: %s", data) return self._parse_decision_payload(data, similar_memories, candidate) + except asyncio.CancelledError as e: + if not self._is_shutdown_in_progress(): + raise + logger.warning(f"LLM dedup decision cancelled: {e}") + return DedupDecision.CREATE, f"LLM cancelled: {e}", [] except Exception as e: logger.warning(f"LLM dedup decision failed: {e}") return DedupDecision.CREATE, f"LLM failed: {e}", [] diff --git a/tests/server/test_server_health.py b/tests/server/test_server_health.py index fff93fe2d..6086360cd 100644 --- a/tests/server/test_server_health.py +++ b/tests/server/test_server_health.py @@ -3,8 +3,13 @@ """Tests for server infrastructure: health, system status, middleware, error handling.""" +import asyncio + import httpx +from openviking.server.app import create_app +from openviking.server.config import ServerConfig + async def test_health_endpoint(client: httpx.AsyncClient): resp = await client.get("/health") @@ -40,3 +45,14 @@ async def test_openviking_error_handler(client: httpx.AsyncClient): async def test_404_for_unknown_route(client: httpx.AsyncClient): resp = await client.get("/this/route/does/not/exist") assert resp.status_code == 404 + + +async def test_lifespan_shutdown_ignores_cancelled_service_close(): + class _Service: + async def close(self): + raise asyncio.CancelledError("shutdown") + + app = create_app(config=ServerConfig(), service=_Service()) + + async with app.router.lifespan_context(app): + pass diff --git a/tests/session/test_memory_dedup_actions.py b/tests/session/test_memory_dedup_actions.py index ac2739652..36aa6fe96 100644 --- a/tests/session/test_memory_dedup_actions.py +++ b/tests/session/test_memory_dedup_actions.py @@ -1,6 +1,7 @@ # Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. # SPDX-License-Identifier: Apache-2.0 +import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -268,6 +269,54 @@ class _DummyConfig: assert similar[4].uri in existing_text assert similar[5].uri not in existing_text + @pytest.mark.asyncio + async def test_llm_decision_falls_back_to_create_on_cancelled_error(self): + dedup = MemoryDeduplicator(vikingdb=_DummyVikingDB()) + dedup.vikingdb.is_closing = True + + class _DummyVLM: + def is_available(self): + return True + + async def get_completion_async(self, _prompt): + raise asyncio.CancelledError("llm shutdown") + + class _DummyConfig: + vlm = _DummyVLM() + + with patch( + "openviking.session.memory_deduplicator.get_openviking_config", + return_value=_DummyConfig(), + ): + decision, reason, actions = await dedup._llm_decision(_make_candidate(), []) + + assert decision == DedupDecision.CREATE + assert "cancelled" in reason.lower() + assert actions == [] + + @pytest.mark.asyncio + async def test_llm_decision_reraises_cancelled_error_when_not_shutting_down(self): + dedup = MemoryDeduplicator(vikingdb=_DummyVikingDB()) + + class _DummyVLM: + def is_available(self): + return True + + async def get_completion_async(self, _prompt): + raise asyncio.CancelledError("llm shutdown") + + class _DummyConfig: + vlm = _DummyVLM() + + with ( + patch( + "openviking.session.memory_deduplicator.get_openviking_config", + return_value=_DummyConfig(), + ), + pytest.raises(asyncio.CancelledError), + ): + await dedup._llm_decision(_make_candidate(), []) + @pytest.mark.asyncio async def test_find_similar_includes_batch_memories(self): """Batch memory with high cosine similarity appears in results.""" diff --git a/tests/test_session_task_tracking.py b/tests/test_session_task_tracking.py index 1306d5003..b3cdc131c 100644 --- a/tests/test_session_task_tracking.py +++ b/tests/test_session_task_tracking.py @@ -351,3 +351,24 @@ async def leaky_commit(_sid, _ctx): error = task_resp.json()["result"]["error"] assert "superSecretKey" not in error assert "[REDACTED]" in error + + +async def test_cancelled_background_commit_is_marked_failed(api_client): + """Cancelled background commits should not surface as unhandled task crashes.""" + client, service = api_client + session_id = await _new_session_with_message(client) + + async def cancelled_commit(_sid, _ctx): + raise asyncio.CancelledError() + + service.sessions.commit_async = cancelled_commit + + resp = await client.post(f"/api/v1/sessions/{session_id}/commit", params={"wait": False}) + task_id = resp.json()["result"]["task_id"] + + await asyncio.sleep(0.2) + + task_resp = await client.get(f"/api/v1/tasks/{task_id}") + result = task_resp.json()["result"] + assert result["status"] == "failed" + assert "cancelled" in result["error"].lower() diff --git a/tests/transaction/test_lock_manager.py b/tests/transaction/test_lock_manager.py index e30f724bc..83f461802 100644 --- a/tests/transaction/test_lock_manager.py +++ b/tests/transaction/test_lock_manager.py @@ -2,7 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for LockManager.""" +import asyncio import uuid +from unittest.mock import AsyncMock, MagicMock import pytest @@ -86,3 +88,14 @@ async def test_nonexistent_path_fails(self, lm): handle = lm.create_handle() ok = await lm.acquire_point(handle, "/local/nonexistent-xyz") assert ok is False + + async def test_recover_pending_redo_preserves_cancelled_error(self, lm): + lm._redo_log = MagicMock() + lm._redo_log.list_pending.return_value = ["redo-task"] + lm._redo_log.read.return_value = {"archive_uri": "a", "session_uri": "b"} + lm._redo_session_memory = AsyncMock(side_effect=asyncio.CancelledError("shutdown")) + + with pytest.raises(asyncio.CancelledError): + await lm._recover_pending_redo() + + lm._redo_log.mark_done.assert_not_called()