diff --git a/src/diagnose_vault.py b/src/diagnose_vault.py index 6225359..aa1e782 100644 --- a/src/diagnose_vault.py +++ b/src/diagnose_vault.py @@ -1,5 +1,6 @@ """Diagnose vault files to find empty string issues.""" +import asyncio import os import sys from pathlib import Path @@ -9,61 +10,67 @@ from src.embedder import VoyageEmbedder from src.exceptions import EmbeddingError -vault_path = Path(os.getenv("OBSIDIAN_VAULT_PATH", "/vault")) -embedder = VoyageEmbedder() -print(f"Scanning vault: {vault_path}") -md_files = list(vault_path.rglob("*.md")) -print(f"Found {len(md_files)} markdown files\n") +async def diagnose(): + vault_path = Path(os.getenv("OBSIDIAN_VAULT_PATH", "/vault")) + embedder = VoyageEmbedder() -problematic_files = [] -empty_files = [] -successful = 0 + print(f"Scanning vault: {vault_path}") + md_files = list(vault_path.rglob("*.md")) + print(f"Found {len(md_files)} markdown files\n") -for i, file_path in enumerate(md_files, 1): - try: - with open(file_path, encoding="utf-8") as f: - content = f.read().strip() + problematic_files = [] + empty_files = [] + successful = 0 - rel_path = str(file_path.relative_to(vault_path)) - - if not content: - empty_files.append(rel_path) - print(f"[{i}/{len(md_files)}] EMPTY: {rel_path}") - continue - - if len(content) < 5: - print(f"[{i}/{len(md_files)}] VERY SHORT ({len(content)} chars): {rel_path}") - - # Try to embed single file + for i, file_path in enumerate(md_files, 1): try: - embedding = embedder.embed(content, input_type="document") - successful += 1 - if i % 10 == 0: - print(f"[{i}/{len(md_files)}] āœ“ Progress: {successful} successful") - - except EmbeddingError as e: - problematic_files.append((rel_path, f"Embedding failed: {e}")) - print(f"[{i}/{len(md_files)}] FAILED: {rel_path} - {e}") - - except Exception as e: - problematic_files.append((rel_path, str(e))) - print(f"[{i}/{len(md_files)}] ERROR: {rel_path} - {e}") - -print(f"\n{'=' * 60}") -print("Summary:") -print(f" Successful: {successful}/{len(md_files)}") -print(f" Empty files: {len(empty_files)}") -print(f" Problematic: {len(problematic_files)}") - -if empty_files: - print(f"\nEmpty files ({len(empty_files)}):") - for path in empty_files[:10]: - print(f" - {path}") - if len(empty_files) > 10: - print(f" ... and {len(empty_files) - 10} more") - -if problematic_files: - print(f"\nProblematic files ({len(problematic_files)}):") - for path, error in problematic_files[:10]: - print(f" - {path}: {error}") + with open(file_path, encoding="utf-8") as f: + content = f.read().strip() + + rel_path = str(file_path.relative_to(vault_path)) + + if not content: + empty_files.append(rel_path) + print(f"[{i}/{len(md_files)}] EMPTY: {rel_path}") + continue + + if len(content) < 5: + print(f"[{i}/{len(md_files)}] VERY SHORT " f"({len(content)} chars): {rel_path}") + + # Try to embed single file + try: + await embedder.embed(content, input_type="document") + successful += 1 + if i % 10 == 0: + print(f"[{i}/{len(md_files)}] " f"Progress: {successful} successful") + + except EmbeddingError as e: + problematic_files.append((rel_path, f"Embedding failed: {e}")) + print(f"[{i}/{len(md_files)}] FAILED: {rel_path} - {e}") + + except Exception as e: + problematic_files.append((rel_path, str(e))) + print(f"[{i}/{len(md_files)}] ERROR: {rel_path} - {e}") + + print(f"\n{'=' * 60}") + print("Summary:") + print(f" Successful: {successful}/{len(md_files)}") + print(f" Empty files: {len(empty_files)}") + print(f" Problematic: {len(problematic_files)}") + + if empty_files: + print(f"\nEmpty files ({len(empty_files)}):") + for path in empty_files[:10]: + print(f" - {path}") + if len(empty_files) > 10: + print(f" ... and {len(empty_files) - 10} more") + + if problematic_files: + print(f"\nProblematic files ({len(problematic_files)}):") + for path, error in problematic_files[:10]: + print(f" - {path}: {error}") + + +if __name__ == "__main__": + asyncio.run(diagnose()) diff --git a/src/embedder.py b/src/embedder.py index 873493f..3764893 100644 --- a/src/embedder.py +++ b/src/embedder.py @@ -146,7 +146,7 @@ def chunk_text(self, text: str, chunk_size: int = 2000, overlap: int = 0) -> lis ) return chunks - def embed_with_chunks( + async def embed_with_chunks( self, text: str, chunk_size: int = 2000, input_type: str = "document" ) -> tuple[list[list[float]], int]: """ @@ -174,7 +174,7 @@ def embed_with_chunks( # If under limit, embed whole if estimated_tokens < 30000: try: - embedding = self.embed(text, input_type=input_type) + embedding = await self.embed(text, input_type=input_type) return ([embedding], 1) except EmbeddingError as e: if _is_token_limit_error(e): @@ -206,12 +206,12 @@ def embed_with_chunks( while i < len(chunks): chunk_batch = chunks[i : i + batch_size] - # Rate limit - self._rate_limit_sync() + # Rate limit (async - non-blocking) + await self._rate_limit_async() try: - # Embed this batch of chunks with context - result = self._call_api_with_retry( + # Embed this batch of chunks with context (runs in thread pool) + result = await self._call_api_with_timeout( self.client.contextualized_embed, inputs=[chunk_batch], # One document's chunks model=self.model, @@ -230,7 +230,7 @@ def embed_with_chunks( # Halve batch size and retry this batch batch_size = max(1, batch_size // 2) logger.warning( - f"Batch too large for token limit, reducing to {batch_size} chunks" + f"Batch too large for token limit, " f"reducing to {batch_size} chunks" ) continue # Retry same position with smaller batch raise @@ -242,7 +242,9 @@ def embed_with_chunks( error_msg = _redact_sensitive(str(e)) logger.error(f"Chunked embedding failed: {error_msg}", exc_info=True) raise EmbeddingError( - f"Failed to embed chunked text: {error_msg}", text_preview=text[:100], cause=e + f"Failed to embed chunked text: {error_msg}", + text_preview=text[:100], + cause=e, ) from e def _load_cache_index(self) -> dict: @@ -315,7 +317,8 @@ def _call_api_with_retry(self, api_func, *args, **kwargs): # Exponential backoff: 2^attempt seconds (1, 2, 4, ...) backoff = 2 ** (attempt + 1) logger.warning( - f"Rate limited, retrying in {backoff}s (attempt {attempt + 1}/{self.max_retries})" + f"Rate limited, retrying in {backoff}s " + f"(attempt {attempt + 1}/{self.max_retries})" ) time.sleep(backoff) elif attempt < self.max_retries - 1: @@ -327,10 +330,13 @@ def _call_api_with_retry(self, api_func, *args, **kwargs): ) time.sleep(backoff) else: - logger.error(f"API call failed after {self.max_retries} attempts: {error_msg}") + logger.error( + f"API call failed after {self.max_retries} attempts: " f"{error_msg}" + ) raise EmbeddingError( - f"API call failed after {self.max_retries} attempts: {_redact_sensitive(str(last_error))}", + f"API call failed after {self.max_retries} attempts: " + f"{_redact_sensitive(str(last_error))}", cause=last_error, ) @@ -349,7 +355,7 @@ async def _call_api_with_timeout(self, api_func, *args, **kwargs): Raises: EmbeddingError: If timeout or API error occurs """ - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() try: result = await asyncio.wait_for( @@ -363,7 +369,9 @@ async def _call_api_with_timeout(self, api_func, *args, **kwargs): except TimeoutError as e: raise EmbeddingError(f"API call timed out after {self.api_timeout}s", cause=e) from e - def embed(self, text: str, input_type: str = "document", use_cache: bool = True) -> list[float]: + async def embed( + self, text: str, input_type: str = "document", use_cache: bool = True + ) -> list[float]: """ Generate embedding for a single text. @@ -378,14 +386,14 @@ def embed(self, text: str, input_type: str = "document", use_cache: bool = True) Raises: EmbeddingError: If embedding generation fails """ - results = self.embed_batch([text], input_type, use_cache) + results = await self.embed_batch([text], input_type, use_cache) if not results or results[0] is None: raise EmbeddingError("Failed to generate embedding for text", text_preview=text[:100]) return results[0] - def embed_batch( + async def embed_batch( self, texts: list[str], input_type: str = "document", use_cache: bool = True ) -> list[list[float]]: """ @@ -435,8 +443,8 @@ def embed_batch( for i in range(0, len(texts_to_embed), self.batch_size): batch = texts_to_embed[i : i + self.batch_size] - # Rate limiting - self._rate_limit_sync() + # Rate limiting (async - non-blocking) + await self._rate_limit_async() try: # Filter out empty strings (Voyage API rejects them) @@ -459,8 +467,8 @@ def embed_batch( # Each note is a single-element list (whole note, not chunked) nested_inputs = [[text] for text in non_empty] - # Call Voyage API with retry and error handling - result = self._call_api_with_retry( + # Call Voyage API with timeout (runs in thread pool, non-blocking) + result = await self._call_api_with_timeout( self.client.contextualized_embed, inputs=nested_inputs, model=self.model, @@ -473,7 +481,8 @@ def embed_batch( # Since we pass whole notes as single chunks, we take [0] api_embeddings = [doc_result.embeddings[0] for doc_result in result.results] - # Map back to original batch positions (accounting for None placeholders) + # Map back to original batch positions + # (accounting for None placeholders) embedding_idx = 0 for text in filtered_batch: if text is None: @@ -482,7 +491,7 @@ def embed_batch( new_embeddings.append(api_embeddings[embedding_idx]) embedding_idx += 1 - # Cache results using JSON (safer than pickle) + # Cache results using JSON (safer than other formats) if use_cache: # Cache only non-None embeddings for text, embedding in zip(non_empty, api_embeddings, strict=False): diff --git a/src/file_watcher.py b/src/file_watcher.py index 78aa37e..a03fe76 100644 --- a/src/file_watcher.py +++ b/src/file_watcher.py @@ -429,7 +429,7 @@ async def _reindex_file(self, file_path: str): # Generate embedding(s) with automatic chunking for large notes try: - embeddings_list, total_chunks = self.embedder.embed_with_chunks( + embeddings_list, total_chunks = await self.embedder.embed_with_chunks( content, chunk_size=2000, input_type="document" ) except EmbeddingError as e: diff --git a/src/indexer.py b/src/indexer.py index b2d9679..ef971df 100644 --- a/src/indexer.py +++ b/src/indexer.py @@ -165,7 +165,7 @@ async def index_vault(vault_path: str, batch_size: int = 100): for note_data in valid_notes: # embed_with_chunks handles both small (whole) and large (chunked) notes try: - embeddings_list, total_chunks = embedder.embed_with_chunks( + embeddings_list, total_chunks = await embedder.embed_with_chunks( note_data["content"], chunk_size=2000, # oachatbot standard input_type="document", diff --git a/src/security_utils.py b/src/security_utils.py index 70fdb38..0086fff 100644 --- a/src/security_utils.py +++ b/src/security_utils.py @@ -89,8 +89,11 @@ def validate_vault_path(user_path: str, vault_root: str) -> str: raise SecurityError(f"Path traversal detected: {user_path}") # 4. Resolve against vault root and ensure it stays within bounds - vault_root_resolved = Path(vault_root).resolve() - full_path = (vault_root_resolved / sanitized).resolve() + try: + vault_root_resolved = Path(vault_root).resolve() + full_path = (vault_root_resolved / sanitized).resolve() + except (OSError, RuntimeError) as e: + raise SecurityError(f"Path resolution failed for '{user_path}': {e}") from e # Check if resolved path is still within vault try: diff --git a/src/server.py b/src/server.py index e7a6bcd..8d1cd08 100644 --- a/src/server.py +++ b/src/server.py @@ -103,7 +103,7 @@ async def initialize_server(): ) # Start file watching first (creates event_handler) - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() vault_watcher.start(loop) # Run startup scan to catch files changed while offline @@ -184,7 +184,9 @@ async def list_tools() -> list[Tool]: ), Tool( name="get_connection_graph", - description="Build multi-hop connection graph using BFS traversal to discover relationships", + description=( + "Build multi-hop connection graph using BFS traversal" " to discover relationships" + ), inputSchema={ "type": "object", "properties": { @@ -231,7 +233,7 @@ async def list_tools() -> list[Tool]: }, "threshold": { "type": "number", - "description": "Similarity threshold for counting connections (0.0-1.0)", + "description": ("Similarity threshold for counting connections (0.0-1.0)"), "default": 0.5, "minimum": 0.0, "maximum": 1.0, @@ -260,7 +262,7 @@ async def list_tools() -> list[Tool]: }, "threshold": { "type": "number", - "description": "Similarity threshold for counting connections (0.0-1.0)", + "description": ("Similarity threshold for counting connections (0.0-1.0)"), "default": 0.5, "minimum": 0.0, "maximum": 1.0, @@ -308,10 +310,15 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[dict[str, Any] # Generate query embedding try: - query_embedding = ctx.embedder.embed(query, input_type="query") + query_embedding = await ctx.embedder.embed(query, input_type="query") except EmbeddingError as e: logger.error(f"Query embedding failed: {e}", exc_info=True) - return [{"type": "text", "text": f"Error: Failed to generate query embedding: {e}"}] + return [ + { + "type": "text", + "text": f"Error: Failed to generate query embedding: {e}", + } + ] # Search results = await ctx.store.search(query_embedding, limit, threshold) @@ -322,7 +329,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[dict[str, Any] snippet = ( result.content[:200] + "..." if len(result.content) > 200 else result.content ) - response += f"{i}. **{result.title}** (similarity: {result.similarity:.3f})\n" + response += f"{i}. **{result.title}** " f"(similarity: {result.similarity:.3f})\n" response += f" Path: `{result.path}`\n" response += f" {snippet}\n\n" @@ -342,7 +349,8 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[dict[str, Any] # SECURITY: Validate note_path before processing note_path = validate_note_path_parameter( - validated["note_path"], vault_path=os.getenv("OBSIDIAN_VAULT_PATH", "/vault") + validated["note_path"], + vault_path=os.getenv("OBSIDIAN_VAULT_PATH", "/vault"), ) limit = validated["limit"] threshold = validated["threshold"] @@ -353,7 +361,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[dict[str, Any] # Format results response = f"Notes similar to `{note_path}`:\n\n" for i, result in enumerate(results, 1): - response += f"{i}. **{result.title}** (similarity: {result.similarity:.3f})\n" + response += f"{i}. **{result.title}** " f"(similarity: {result.similarity:.3f})\n" response += f" Path: `{result.path}`\n\n" return [{"type": "text", "text": response}] @@ -375,7 +383,8 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[dict[str, Any] # SECURITY: Validate note_path before processing note_path = validate_note_path_parameter( - validated["note_path"], vault_path=os.getenv("OBSIDIAN_VAULT_PATH", "/vault") + validated["note_path"], + vault_path=os.getenv("OBSIDIAN_VAULT_PATH", "/vault"), ) depth = validated["depth"] max_per_level = validated["max_per_level"] @@ -389,7 +398,10 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[dict[str, Any] # Format results response = f"# Connection Graph: {graph['root']['title']}\n\n" response += f"**Starting note:** `{graph['root']['path']}`\n" - response += f"**Network size:** {graph['stats']['total_nodes']} nodes, {graph['stats']['total_edges']} edges\n\n" + response += ( + f"**Network size:** {graph['stats']['total_nodes']} nodes, " + f"{graph['stats']['total_edges']} edges\n\n" + ) # Group nodes by level nodes_by_level = {} @@ -407,10 +419,14 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[dict[str, Any] if node["parent_path"]: # Find edge to get similarity edge = next( - (e for e in graph["edges"] if e["target"] == node["path"]), None + (e for e in graph["edges"] if e["target"] == node["path"]), + None, ) if edge: - response += f" Connected from: `{node['parent_path']}` (similarity: {edge['similarity']:.3f})\n" + response += ( + f" Connected from: `{node['parent_path']}` " + f"(similarity: {edge['similarity']:.3f})\n" + ) return [{"type": "text", "text": response}] @@ -437,12 +453,17 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[dict[str, Any] # Format results if not hubs: - response = f"No hub notes found with >={min_connections} connections at threshold {threshold}" + response = ( + f"No hub notes found with >={min_connections} connections " + f"at threshold {threshold}" + ) else: response = "# Hub Notes (Highly Connected)\n\n" - response += f"Found {len(hubs)} notes with >={min_connections} connections:\n\n" + response += f"Found {len(hubs)} notes with " f">={min_connections} connections:\n\n" for i, hub in enumerate(hubs, 1): - response += f"{i}. **{hub['title']}** ({hub['connection_count']} connections)\n" + response += ( + f"{i}. **{hub['title']}** " f"({hub['connection_count']} connections)\n" + ) response += f" Path: `{hub['path']}`\n\n" return [{"type": "text", "text": response}] @@ -467,13 +488,16 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[dict[str, Any] # Format results if not orphans: - response = f"No orphaned notes found with <={max_connections} connections" + response = f"No orphaned notes found with " f"<={max_connections} connections" else: response = "# Orphaned Notes (Isolated)\n\n" - response += f"Found {len(orphans)} notes with <={max_connections} connections:\n\n" + response += ( + f"Found {len(orphans)} notes with " f"<={max_connections} connections:\n\n" + ) for i, orphan in enumerate(orphans, 1): response += ( - f"{i}. **{orphan['title']}** ({orphan['connection_count']} connections)\n" + f"{i}. **{orphan['title']}** " + f"({orphan['connection_count']} connections)\n" ) response += f" Path: `{orphan['path']}`\n" if orphan["modified_at"]: diff --git a/src/vector_store.py b/src/vector_store.py index 2dcc976..321d1a5 100644 --- a/src/vector_store.py +++ b/src/vector_store.py @@ -120,7 +120,8 @@ async def initialize(self) -> None: # Verify notes table exists table_exists = await conn.fetchval( - "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name = 'notes')" + "SELECT EXISTS(SELECT 1 FROM information_schema.tables " + "WHERE table_name = 'notes')" ) if not table_exists: logger.warning("Notes table does not exist yet (will be created by schema.sql)") @@ -132,6 +133,10 @@ async def initialize(self) -> None: except Exception as e: raise VectorStoreError(f"PostgreSQL initialization failed: {e}") from e + async def _with_timeout(self, coro, timeout=10.0): + """Execute a coroutine with a timeout.""" + return await asyncio.wait_for(coro, timeout=timeout) + async def _setup_connection(self, conn): """Setup each connection with pgvector support.""" await register_vector(conn) @@ -163,7 +168,8 @@ async def search( if len(query_embedding) != EMBEDDING_DIMENSIONS: raise VectorStoreError( - f"Query embedding must be {EMBEDDING_DIMENSIONS} dimensions, got {len(query_embedding)}" + f"Query embedding must be {EMBEDDING_DIMENSIONS} dimensions, " + f"got {len(query_embedding)}" ) try: @@ -188,7 +194,8 @@ async def search( async with self.pool.acquire() as conn: start_time = time.time() rows = await asyncio.wait_for( - conn.fetch(query, query_embedding, distance_threshold, limit), timeout=5.0 + conn.fetch(query, query_embedding, distance_threshold, limit), + timeout=5.0, ) query_time_ms = (time.time() - start_time) * 1000 @@ -230,8 +237,9 @@ async def get_similar_notes( try: async with self.pool.acquire() as conn: # Fetch source note's embedding - source_embedding = await conn.fetchval( - "SELECT embedding FROM notes WHERE path = $1", note_path + source_embedding = await self._with_timeout( + conn.fetchval("SELECT embedding FROM notes WHERE path = $1", note_path), + timeout=5.0, ) if source_embedding is None: @@ -248,6 +256,8 @@ async def get_similar_notes( results = [r for r in results if r.path != note_path] return results[:limit] + except TimeoutError as e: + raise VectorStoreError("Similar notes search timed out") from e except Exception as e: raise VectorStoreError(f"Similar notes search failed: {e}") from e @@ -273,7 +283,8 @@ async def upsert_note(self, note: Note) -> bool: try: query = """ - INSERT INTO notes (path, title, content, embedding, modified_at, file_size_bytes, chunk_index, total_chunks, last_indexed_at) + INSERT INTO notes (path, title, content, embedding, modified_at, + file_size_bytes, chunk_index, total_chunks, last_indexed_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, CURRENT_TIMESTAMP) ON CONFLICT (path, chunk_index) DO UPDATE SET title = EXCLUDED.title, @@ -287,21 +298,26 @@ async def upsert_note(self, note: Note) -> bool: """ async with self.pool.acquire() as conn: - await conn.execute( - query, - note.path, - note.title, - note.content, - note.embedding, - note.modified_at, - note.file_size_bytes, - note.chunk_index, - note.total_chunks, + await self._with_timeout( + conn.execute( + query, + note.path, + note.title, + note.content, + note.embedding, + note.modified_at, + note.file_size_bytes, + note.chunk_index, + note.total_chunks, + ), + timeout=10.0, ) logger.debug(f"Upserted note: {note.path}") return True + except TimeoutError as e: + raise VectorStoreError("Note upsert timed out") from e except Exception as e: raise VectorStoreError(f"Note upsert failed: {e}") from e @@ -331,7 +347,8 @@ async def upsert_batch(self, notes: list[Note]) -> int: try: query = """ - INSERT INTO notes (path, title, content, embedding, modified_at, file_size_bytes, chunk_index, total_chunks, last_indexed_at) + INSERT INTO notes (path, title, content, embedding, modified_at, + file_size_bytes, chunk_index, total_chunks, last_indexed_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, CURRENT_TIMESTAMP) ON CONFLICT (path, chunk_index) DO UPDATE SET title = EXCLUDED.title, @@ -359,11 +376,16 @@ async def upsert_batch(self, notes: list[Note]) -> int: async with self.pool.acquire() as conn: async with conn.transaction(): - await conn.executemany(query, batch_data) + await self._with_timeout( + conn.executemany(query, batch_data), + timeout=30.0, + ) logger.info(f"Batch upserted {len(notes)} notes") return len(notes) + except TimeoutError as e: + raise VectorStoreError("Batch upsert timed out") from e except Exception as e: raise VectorStoreError(f"Batch upsert failed: {e}") from e @@ -374,7 +396,12 @@ async def get_note_count(self) -> int: try: async with self.pool.acquire() as conn: - return await conn.fetchval("SELECT COUNT(*) FROM notes") + return await self._with_timeout( + conn.fetchval("SELECT COUNT(*) FROM notes"), + timeout=5.0, + ) + except TimeoutError as e: + raise VectorStoreError("Count query timed out") from e except Exception as e: raise VectorStoreError(f"Count query failed: {e}") from e @@ -397,11 +424,17 @@ async def delete_notes_by_paths(self, paths: list[str]) -> int: try: async with self.pool.acquire() as conn: # Use RETURNING to get distinct note count (not chunk count) - rows = await conn.fetch( - "DELETE FROM notes WHERE path = ANY($1) RETURNING path", paths + rows = await self._with_timeout( + conn.fetch( + "DELETE FROM notes WHERE path = ANY($1) RETURNING path", + paths, + ), + timeout=10.0, ) # Count distinct paths (a chunked note has multiple rows with same path) return len({row["path"] for row in rows}) + except TimeoutError as e: + raise VectorStoreError("Delete operation timed out") from e except Exception as e: raise VectorStoreError(f"Delete failed: {e}") from e @@ -417,8 +450,13 @@ async def get_all_paths(self) -> list[str]: try: async with self.pool.acquire() as conn: - rows = await conn.fetch("SELECT DISTINCT path FROM notes") + rows = await self._with_timeout( + conn.fetch("SELECT DISTINCT path FROM notes"), + timeout=10.0, + ) return [row["path"] for row in rows] + except TimeoutError as e: + raise VectorStoreError("Get paths query timed out") from e except Exception as e: raise VectorStoreError(f"Get paths failed: {e}") from e diff --git a/tests/conftest.py b/tests/conftest.py index bf632a7..01ee332 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -72,9 +72,9 @@ def mock_embedder(): # Return dummy 1024-dim vector (successful embedding) dummy_embedding = [0.1] * 1024 - embedder.embed = MagicMock(return_value=dummy_embedding) - embedder.embed_batch = MagicMock(return_value=[dummy_embedding]) - embedder.embed_with_chunks = MagicMock(return_value=([dummy_embedding], 1)) + embedder.embed = AsyncMock(return_value=dummy_embedding) + embedder.embed_batch = AsyncMock(return_value=[dummy_embedding]) + embedder.embed_with_chunks = AsyncMock(return_value=([dummy_embedding], 1)) embedder.chunk_text = MagicMock(return_value=["chunk1"]) embedder.get_cache_stats = MagicMock( return_value={ diff --git a/tests/test_error_paths.py b/tests/test_error_paths.py index 9930ee3..7fcd3b3 100644 --- a/tests/test_error_paths.py +++ b/tests/test_error_paths.py @@ -29,7 +29,7 @@ async def test_search_notes_embedding_failure(self, server_context): from src.server import call_tool # Configure mock to raise EmbeddingError (simulating Voyage API failure) - server_context.embedder.embed = MagicMock( + server_context.embedder.embed = AsyncMock( side_effect=EmbeddingError("Voyage API rate limited", text_preview="test query") ) @@ -61,7 +61,7 @@ async def timeout_side_effect(*args, **kwargs): raise TimeoutError("Connection pool exhausted") server_context.store.search = AsyncMock(side_effect=timeout_side_effect) - server_context.embedder.embed = MagicMock(return_value=[0.1] * 1024) + server_context.embedder.embed = AsyncMock(return_value=[0.1] * 1024) # Call search_notes result = await call_tool("search_notes", {"query": "test", "limit": 10, "threshold": 0.5}) diff --git a/tests/test_indexer.py b/tests/test_indexer.py index 447ab8e..a1f0a2d 100644 --- a/tests/test_indexer.py +++ b/tests/test_indexer.py @@ -11,7 +11,7 @@ import sys from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -181,7 +181,7 @@ async def test_index_vault_handles_large_notes_with_chunking(tmp_path, mock_stor (vault / "large.md").write_text(large_content) # Mock embedder to return multiple chunks - mock_embedder.embed_with_chunks = MagicMock( + mock_embedder.embed_with_chunks = AsyncMock( return_value=([[0.1] * 1024, [0.2] * 1024, [0.3] * 1024], 3) # 3 chunks # total_chunks ) mock_embedder.chunk_text = MagicMock( diff --git a/tests/test_mcp_tools_integration.py b/tests/test_mcp_tools_integration.py index 4d454a1..95a488f 100644 --- a/tests/test_mcp_tools_integration.py +++ b/tests/test_mcp_tools_integration.py @@ -79,7 +79,7 @@ async def test_mcp_tools_integration(tmp_path): ] texts = [content for _, _, content in test_notes] - embeddings = embedder.embed_batch(texts, input_type="document") + embeddings = await embedder.embed_batch(texts, input_type="document") notes = [] for (path, title, content), embedding in zip(test_notes, embeddings, strict=False): @@ -101,7 +101,7 @@ async def test_mcp_tools_integration(tmp_path): print("\nšŸ” Test 1: search_notes") print("-" * 60) - query_emb = embedder.embed("neural networks deep learning", input_type="query") + query_emb = await embedder.embed("neural networks deep learning", input_type="query") start = time.time() results = await store.search(query_emb, limit=5, threshold=0.0) latency_ms = (time.time() - start) * 1000 diff --git a/tests/test_tools.py b/tests/test_tools.py index f504a9c..3ce3913 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -93,7 +93,7 @@ async def setup_test_data(): # Generate embeddings and insert texts = [content for _, _, content in test_notes_content] - embeddings = embedder.embed_batch(texts, input_type="document") + embeddings = await embedder.embed_batch(texts, input_type="document") notes = [] for (path, title, content), embedding in zip(test_notes_content, embeddings, strict=False): @@ -123,7 +123,7 @@ async def test_search_notes_similarity_range(setup_test_data): """Ensure similarity scores are in [0.0, 1.0] range.""" store, embedder = setup_test_data - query_embedding = embedder.embed("machine learning algorithms", input_type="query") + query_embedding = await embedder.embed("machine learning algorithms", input_type="query") results = await store.search(query_embedding, limit=10, threshold=0.0) for result in results: @@ -137,7 +137,7 @@ async def test_search_notes_performance(setup_test_data): """Verify search latency < 500ms.""" store, embedder = setup_test_data - query_embedding = embedder.embed("neural networks", input_type="query") + query_embedding = await embedder.embed("neural networks", input_type="query") start = time.time() results = await store.search(query_embedding, limit=10) @@ -151,7 +151,7 @@ async def test_search_notes_threshold(setup_test_data): """Verify threshold filtering works.""" store, embedder = setup_test_data - query_embedding = embedder.embed("machine learning", input_type="query") + query_embedding = await embedder.embed("machine learning", input_type="query") results = await store.search(query_embedding, limit=10, threshold=0.2) for result in results: