From 4b5a840e2d53a6ea309badeb88e90db0b35b64a1 Mon Sep 17 00:00:00 2001 From: Mashhood Siddiqui Date: Mon, 28 Jul 2025 03:18:24 +0530 Subject: [PATCH] feat(qdrant): improve reliability and observability in qdrant_retrieve MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • switch to asyncio.Lock client init to avoid deadlocks • add timeout handling for query_points and scroll operations • validate endpoint name with clear ValueError • emit standard logger.info/error alongside log_with_context • normalize retrieve_item return type to empty list on “not found” --- .../retrieval_providers/qdrant_retrieve.py | 180 +++++++++++++----- 1 file changed, 133 insertions(+), 47 deletions(-) diff --git a/code/python/retrieval_providers/qdrant_retrieve.py b/code/python/retrieval_providers/qdrant_retrieve.py index b7d4b8202..032281192 100644 --- a/code/python/retrieval_providers/qdrant_retrieve.py +++ b/code/python/retrieval_providers/qdrant_retrieve.py @@ -3,15 +3,19 @@ """ WARNING: This code is under development and may undergo changes in future releases. -Backwards compatibility is not guaranteed at this time. +Backwards compatibility not guaranteed at this time. """ -from typing import Dict -from qdrant_client import AsyncQdrantClient -from qdrant_client.http import models +import asyncio import json import time import threading +from typing import Any, Dict, List, Optional + +from qdrant_client import AsyncQdrantClient +from qdrant_client.http import models +from qdrant_client.http.exceptions import UnexpectedResponse + from core.embedding import get_embedding from core.config import CONFIG from misc.logger.logging_config_helper import get_configured_logger @@ -19,14 +23,14 @@ logger = get_configured_logger("qdrant_retrieve") -_client_lock = threading.Lock() +# Use an asyncio.Lock instead of threading.Lock to avoid deadlock in async contexts +_client_lock = asyncio.Lock() qdrant_clients: Dict[str, AsyncQdrantClient] = {} def _create_client_params(endpoint_config): """Extract client parameters from endpoint config.""" params = {} - url = endpoint_config.api_endpoint path = endpoint_config.database_path api_key = endpoint_config.api_key @@ -38,37 +42,41 @@ def _create_client_params(endpoint_config): elif path: params["path"] = path else: - raise ValueError("Either `api_endpoint_env` or `database_path` must be set.") + raise ValueError("Either `api_endpoint` or `database_path` must be set.") return params -async def initialize_client(endpoint_name=None): +async def initialize_client(endpoint_name: Optional[str] = None): """Initialize Qdrant client.""" - global qdrant_clients endpoint_name = endpoint_name or CONFIG.write_endpoint - with _client_lock: + # Validate endpoint exists + if endpoint_name not in CONFIG.retrieval_endpoints: + raise ValueError(f"Unknown Qdrant endpoint: {endpoint_name}") + + async with _client_lock: if endpoint_name not in qdrant_clients: logger.info(f"Initializing Qdrant client for endpoint: {endpoint_name}") try: endpoint_config = CONFIG.retrieval_endpoints[endpoint_name] - params = _create_client_params(endpoint_config) + client = AsyncQdrantClient(**params) + qdrant_clients[endpoint_name] = client - qdrant_clients[endpoint_name] = AsyncQdrantClient(**params) - logger.info( - f"Successfully initialized Qdrant client for {endpoint_name}" - ) - - await qdrant_clients[endpoint_name].get_collections() + # Test connection + await asyncio.wait_for(client.get_collections(), timeout=10.0) + logger.info(f"Successfully initialized Qdrant client for {endpoint_name}") logger.debug("Qdrant connection test successful") + except asyncio.TimeoutError: + logger.error(f"Timeout initializing Qdrant client for {endpoint_name}") + raise except Exception as e: - logger.exception(f"Failed to initialize Qdrant client: {str(e)}") + logger.exception(f"Failed to initialize Qdrant client: {e}") raise -async def get_qdrant_client(endpoint_name=None): +async def get_qdrant_client(endpoint_name: Optional[str] = None) -> AsyncQdrantClient: """Get or initialize Qdrant client.""" endpoint_name = endpoint_name or CONFIG.write_endpoint if endpoint_name not in qdrant_clients: @@ -76,15 +84,17 @@ async def get_qdrant_client(endpoint_name=None): return qdrant_clients[endpoint_name] -def get_collection_name(endpoint_name=None): +def get_collection_name(endpoint_name: Optional[str] = None) -> str: """Get collection name from endpoint config or use default.""" endpoint_name = endpoint_name or CONFIG.write_endpoint + if endpoint_name not in CONFIG.retrieval_endpoints: + raise ValueError(f"Unknown Qdrant endpoint: {endpoint_name}") endpoint_config = CONFIG.retrieval_endpoints[endpoint_name] index_name = endpoint_config.index_name return index_name or "nlweb_collection" -def create_site_filter(site): +def create_site_filter(site) -> Optional[models.Filter]: """Create a Qdrant filter for site filtering.""" if site == "all": return None @@ -101,26 +111,33 @@ def create_site_filter(site): ) -def format_results(search_result): +def format_results(search_result) -> List[List[Any]]: """Format Qdrant search results to match expected API: [url, text_json, name, site].""" results = [] for item in search_result: payload = item.payload - url = payload.get("url", "") - schema = payload.get("schema_json", "") - name = payload.get("name", "") - site_name = payload.get("site", "") - - results.append([url, schema, name, site_name]) - + results.append([ + payload.get("url", ""), + payload.get("schema_json", ""), + payload.get("name", ""), + payload.get("site", ""), + ]) return results -async def search_db(query, site, num_results=50, endpoint_name=None, query_params=None): +async def search_db( + query: str, + site: str, + num_results: int = 50, + endpoint_name: Optional[str] = None, + query_params: Optional[Dict[str, Any]] = None, +) -> List[List[Any]]: """Search Qdrant for records filtered by site and ranked by vector similarity.""" endpoint_name = endpoint_name or CONFIG.write_endpoint + logger.info( - f"Starting Qdrant search - endpoint: {endpoint_name}, site: {site}, num_results: {num_results}" + f"Starting Qdrant search - endpoint: {endpoint_name}, site: {site}, " + f"num_results: {num_results}" ) try: @@ -133,19 +150,30 @@ async def search_db(query, site, num_results=50, endpoint_name=None, query_param collection = get_collection_name(endpoint_name) filter_condition = create_site_filter(site) - search_result = ( - await client.query_points( + # Add timeout to the query_points call + response = await asyncio.wait_for( + client.query_points( collection_name=collection, query=embedding, limit=num_results, with_payload=True, query_filter=filter_condition, - ) - ).points + ), + timeout=15.0 + ) + search_result = response.points + retrieve_time = time.time() - start_retrieve results = format_results(search_result) - retrieve_time = time.time() - start_retrieve + # Standard INFO logging + logger.info( + f"Qdrant search completed: embed_time={embed_time:.2f}s, " + f"retrieve_time={retrieve_time:.2f}s, " + f"total_time={(embed_time + retrieve_time):.2f}s, " + f"results_count={len(results)}, embedding_dim={len(embedding)}" + ) + # Structured context logger.log_with_context( LogLevel.INFO, "Qdrant search completed", @@ -160,8 +188,32 @@ async def search_db(query, site, num_results=50, endpoint_name=None, query_param return results + except asyncio.TimeoutError: + logger.error(f"Qdrant query_points timed out after 15s (endpoint={endpoint_name}, site={site})") + raise + except UnexpectedResponse as e: + logger.warning("Qdrant collection likely missing - did you run indexing first?") + logger.error( + f"Qdrant search failed: {type(e).__name__}: {e} " + f"(endpoint={endpoint_name}, site={site})" + ) + logger.log_with_context( + LogLevel.ERROR, + "Qdrant search failed", + { + "error_type": type(e).__name__, + "error_message": str(e), + "endpoint": endpoint_name, + "site": site, + }, + ) + raise except Exception as e: - logger.exception(f"Error in Qdrant search_db: {str(e)}") + logger.exception(f"Error in Qdrant search_db: {e}") + logger.error( + f"Qdrant search failed: {type(e).__name__}: {e} " + f"(endpoint={endpoint_name}, site={site})" + ) logger.log_with_context( LogLevel.ERROR, "Qdrant search failed", @@ -175,29 +227,34 @@ async def search_db(query, site, num_results=50, endpoint_name=None, query_param raise -async def retrieve_item_with_url(url, endpoint_name=None): +async def retrieve_item_with_url( + url: str, + endpoint_name: Optional[str] = None) -> List[Any]: """Retrieve a specific item by URL from Qdrant database.""" endpoint_name = endpoint_name or CONFIG.write_endpoint - logger.info(f"Retrieving item by URL: {url}") + logger.info(f"Retrieving Qdrant item - url: {url}, endpoint: {endpoint_name}") try: client = await get_qdrant_client(endpoint_name) collection = get_collection_name(endpoint_name) - filter_condition = models.Filter( must=[models.FieldCondition(key="url", match=models.MatchValue(value=url))] ) - points, _offset = await client.scroll( - collection_name=collection, - scroll_filter=filter_condition, - limit=1, - with_payload=True, + response, _offset = await asyncio.wait_for( + client.scroll( + collection_name=collection, + scroll_filter=filter_condition, + limit=1, + with_payload=True, + ), + timeout=10.0 ) + points = response if not points: logger.warning(f"No item found for URL: {url}") - return None + return [] item = points[0] payload = item.payload @@ -209,10 +266,39 @@ async def retrieve_item_with_url(url, endpoint_name=None): ] logger.info(f"Successfully retrieved item for URL: {url}") + logger.log_with_context( + LogLevel.INFO, + "Qdrant item retrieval succeeded", + {"url": url, "endpoint": endpoint_name}, + ) return formatted_result + except asyncio.TimeoutError: + logger.error(f"Qdrant scroll timed out after 10s (url={url}, endpoint={endpoint_name})") + raise + except UnexpectedResponse as e: + logger.warning("Qdrant collection likely missing - did you run indexing first?") + logger.error( + f"Qdrant item retrieval failed: {type(e).__name__}: {e} " + f"(url={url}, endpoint={endpoint_name})" + ) + logger.log_with_context( + LogLevel.ERROR, + "Qdrant item retrieval failed", + { + "error_type": type(e).__name__, + "error_message": str(e), + "url": url, + "endpoint": endpoint_name, + }, + ) + raise except Exception as e: - logger.exception(f"Error retrieving item with URL: {str(e)}") + logger.exception(f"Error retrieving item with URL: {url}: {e}") + logger.error( + f"Qdrant item retrieval failed: {type(e).__name__}: {e} " + f"(url={url}, endpoint={endpoint_name})" + ) logger.log_with_context( LogLevel.ERROR, "Qdrant item retrieval failed",