From 716699d29db43e8eb7b67774864dfe32ccff0b27 Mon Sep 17 00:00:00 2001 From: drewburchfield Date: Fri, 3 Apr 2026 12:43:49 -0500 Subject: [PATCH 1/3] Fix hub notes race condition: await refresh instead of fire-and-forget _ensure_fresh_counts() was using asyncio.create_task() to refresh connection counts in the background, causing get_hubs/get_orphans to query stale data on the first call. Now awaits the refresh inline so counts are always fresh before the query runs. Changes: - Await _refresh_all_counts() directly instead of create_task() - Wait for in-progress refresh instead of skipping when lock is held - Remove _handle_refresh_error callback (errors propagate via await) - Update tests to match new wait-instead-of-skip behavior --- src/hub_analyzer.py | 39 +++++++++++--------------------------- tests/test_hub_analyzer.py | 37 ++++++++++++++++++++++-------------- 2 files changed, 34 insertions(+), 42 deletions(-) diff --git a/src/hub_analyzer.py b/src/hub_analyzer.py index 3dc7a42..36caf5b 100644 --- a/src/hub_analyzer.py +++ b/src/hub_analyzer.py @@ -21,12 +21,12 @@ class HubAnalyzer: Thread-Safety: - Uses asyncio.Lock for refresh operations - Multiple concurrent calls to get_hub_notes/get_orphaned_notes: safe - - Only ONE vault refresh runs at a time (others skip if in progress) + - Only ONE vault refresh runs at a time (others wait for completion) Performance: - Refresh is O(N²) where N = number of notes - Triggered when >50% of notes have stale connection_count - - Runs in background (non-blocking for queries) + - Awaited inline so counts are fresh before queries return """ def __init__(self, store: PostgreSQLVectorStore): @@ -39,22 +39,6 @@ def __init__(self, store: PostgreSQLVectorStore): self.store = store self._refresh_lock = asyncio.Lock() # Replaces refresh_in_progress boolean - def _handle_refresh_error(self, task: asyncio.Task): - """ - Error callback for background refresh tasks. - - Logs errors from background connection count refresh without crashing. - Non-fatal - hub/orphan queries will still work with stale counts. - - Args: - task: Completed asyncio Task to check for errors - """ - try: - task.result() # Raises exception if task failed - except Exception as e: - logger.error(f"Background connection count refresh failed: {e}", exc_info=True) - # Non-fatal - queries will work with stale counts - async def get_hub_notes( self, min_connections: int = 10, threshold: float = 0.5, limit: int = 20 ) -> list[dict]: @@ -156,18 +140,20 @@ async def get_orphaned_notes( async def _ensure_fresh_counts(self, threshold: float): """ - Ensure connection counts are fresh (or trigger background refresh). + Ensure connection counts are fresh before querying. Checks if any notes have stale connection_count (last_indexed_at old). - Triggers background refresh if needed. + Awaits refresh inline so counts are ready when callers query. Thread-Safety: - - Uses non-blocking lock check to avoid queueing multiple refreshes - - If refresh already running, skips check (other request will handle it) + - If refresh already running, waits for it to complete + - Uses _refresh_lock to prevent duplicate refreshes """ - # Non-blocking lock check - if refresh already running, skip + # If refresh already running, wait for it to complete if self._refresh_lock.locked(): - logger.debug("Refresh already in progress, skipping") + logger.debug("Refresh already in progress, waiting...") + async with self._refresh_lock: + pass # Wait for the running refresh to finish return try: @@ -184,10 +170,7 @@ async def _ensure_fresh_counts(self, threshold: float): logger.info( f"{stale_count}/{total_count} notes have stale counts, refreshing..." ) - # Trigger background refresh with error handling - task = asyncio.create_task(self._refresh_all_counts(threshold)) - task.add_done_callback(self._handle_refresh_error) - logger.debug("Scheduled background refresh task") + await self._refresh_all_counts(threshold) except Exception as e: logger.warning(f"Failed to check count freshness: {e}") diff --git a/tests/test_hub_analyzer.py b/tests/test_hub_analyzer.py index 0db24f1..4ac3b08 100644 --- a/tests/test_hub_analyzer.py +++ b/tests/test_hub_analyzer.py @@ -146,8 +146,8 @@ async def tracked_refresh(threshold): @pytest.mark.asyncio -async def test_ensure_fresh_counts_skips_if_locked(mock_store): - """Test that _ensure_fresh_counts skips check if refresh already running.""" +async def test_ensure_fresh_counts_waits_if_locked(mock_store): + """Test that _ensure_fresh_counts waits for lock when refresh already running.""" analyzer = HubAnalyzer(mock_store) # Mock database @@ -164,13 +164,26 @@ async def __aexit__(self, *args): mock_store.pool = MagicMock() mock_store.pool.acquire = MagicMock(return_value=MockAcquire()) - # Acquire lock to simulate refresh in progress - async with analyzer._refresh_lock: - # While locked, ensure_fresh_counts should skip immediately - await analyzer._ensure_fresh_counts(0.5) + # Acquire lock in a background task to simulate refresh in progress, + # then release it after a short delay so _ensure_fresh_counts can proceed + lock_acquired = asyncio.Event() - # Should not have queried database (skipped due to lock) - mock_conn.fetchval.assert_not_called() + async def hold_lock_briefly(): + async with analyzer._refresh_lock: + lock_acquired.set() + await asyncio.sleep(0.05) # Hold lock briefly + + hold_task = asyncio.create_task(hold_lock_briefly()) + await lock_acquired.wait() + + # While locked, ensure_fresh_counts should wait (not skip) + await analyzer._ensure_fresh_counts(0.5) + + # After waiting for lock release, it should return without querying + # (the wait-and-return path doesn't query the database) + mock_conn.fetchval.assert_not_called() + + await hold_task @pytest.mark.asyncio @@ -237,15 +250,11 @@ async def __aexit__(self, *args): # Mock refresh analyzer._refresh_all_counts = AsyncMock() - # Check freshness + # Check freshness - refresh is now awaited inline await analyzer._ensure_fresh_counts(0.5) - # Allow background task to be scheduled - await asyncio.sleep(0.1) - # Should have triggered refresh (60% > 50% threshold) - # Note: refresh runs via asyncio.create_task, so we check it was called - # The actual scheduling depends on event loop timing + analyzer._refresh_all_counts.assert_called_once_with(0.5) @pytest.mark.asyncio From 77f57d36db24b68ecd7cf4380655033e03fb7f38 Mon Sep 17 00:00:00 2001 From: drewburchfield Date: Fri, 3 Apr 2026 12:48:50 -0500 Subject: [PATCH 2/3] Address review: fix TOCTOU race, update docs and race condition test - Move staleness check inside _refresh_lock to prevent duplicate refreshes - Extract _do_refresh (no lock) from _refresh_all_counts (with lock) - Update CONCURRENCY.md to reflect wait-not-skip behavior - Rewrite race condition test to verify exactly 1 refresh with 20 callers - Fix hub_analyzer tests to match new atomic check+refresh pattern --- docs/CONCURRENCY.md | 22 ++--- src/hub_analyzer.py | 179 +++++++++++++++++----------------- tests/test_hub_analyzer.py | 12 +-- tests/test_race_conditions.py | 113 ++++++++------------- 4 files changed, 139 insertions(+), 187 deletions(-) diff --git a/docs/CONCURRENCY.md b/docs/CONCURRENCY.md index 420fce0..fc737e9 100644 --- a/docs/CONCURRENCY.md +++ b/docs/CONCURRENCY.md @@ -56,7 +56,7 @@ class ObsidianFileWatcher: **Thread-Safety Guarantees:** - ✅ Concurrent calls to `get_hub_notes()`: Safe (read-only queries) - ✅ Concurrent calls to `get_orphaned_notes()`: Safe (read-only queries) -- ✅ Multiple refresh requests: Only ONE refresh runs (others skip) +- ✅ Multiple refresh requests: Only ONE refresh runs (others wait for it) **Synchronization Mechanism**: Global async lock for refresh operations @@ -71,22 +71,16 @@ class HubAnalyzer: | Scenario | Behavior | Performance | |----------|----------|-------------| | 20 concurrent `get_hub_notes()` calls | All execute in parallel (read-only) | Optimal | -| 20 concurrent requests trigger refresh | Only 1 refresh runs, others skip | Optimal | -| Refresh running, new `get_hub_notes()` call | Query executes immediately (no blocking) | Optimal | +| 20 concurrent requests trigger refresh | 1 refresh runs, others wait then re-check | Correct | +| Refresh running, new `get_hub_notes()` call | Waits for refresh, then queries fresh data | Correct | **Refresh Logic:** ```python async def _ensure_fresh_counts(self, threshold): - if self._refresh_lock.locked(): # Non-blocking check - return # Skip if refresh already running - - # Check staleness, schedule refresh if needed - asyncio.create_task(self._refresh_all_counts(threshold)) - -async def _refresh_all_counts(self, threshold): - async with self._refresh_lock: # Blocking - waits for lock - # Scan entire vault, update connection_count for all notes - ... + async with self._refresh_lock: # Staleness check + refresh are atomic + # Check staleness inside lock to prevent TOCTOU races + if stale_count / total_count > 0.5: + await self._do_refresh(threshold) # Inline, not fire-and-forget ``` **Lock Granularity**: Coarse-grained (one lock for entire refresh operation) @@ -175,7 +169,7 @@ pytest tests/test_race_conditions.py::test_hub_analyzer_concurrent_refresh_race |------|------------------|-------------------| | File watcher same file | 10 rapid edits | 1 re-index | | File watcher different files | 50 concurrent files | 50 parallel re-indexes | -| Hub analyzer refresh | 20 concurrent requests | 1 refresh execution | +| Hub analyzer refresh | 20 concurrent requests | 1 refresh, others wait | --- diff --git a/src/hub_analyzer.py b/src/hub_analyzer.py index 36caf5b..271a3f1 100644 --- a/src/hub_analyzer.py +++ b/src/hub_analyzer.py @@ -142,42 +142,44 @@ async def _ensure_fresh_counts(self, threshold: float): """ Ensure connection counts are fresh before querying. - Checks if any notes have stale connection_count (last_indexed_at old). - Awaits refresh inline so counts are ready when callers query. + Checks staleness and refreshes inline so counts are ready when callers query. + All logic runs inside the lock to prevent TOCTOU races and duplicate refreshes. Thread-Safety: - - If refresh already running, waits for it to complete - - Uses _refresh_lock to prevent duplicate refreshes + - Staleness check and refresh are atomic (both inside _refresh_lock) + - Concurrent callers wait for the lock, then re-check staleness + - No duplicate refreshes possible """ - # If refresh already running, wait for it to complete - if self._refresh_lock.locked(): - logger.debug("Refresh already in progress, waiting...") - async with self._refresh_lock: - pass # Wait for the running refresh to finish - return - try: - async with self.store.pool.acquire() as conn: - # Check how many notes have connection_count = 0 (likely stale) - stale_count = await conn.fetchval( - "SELECT COUNT(*) FROM notes WHERE connection_count = 0" - ) - - # If >50% of notes have count=0, trigger refresh - total_count = await conn.fetchval("SELECT COUNT(*) FROM notes") - - if total_count > 0 and stale_count / total_count > 0.5: - logger.info( - f"{stale_count}/{total_count} notes have stale counts, refreshing..." + async with self._refresh_lock: + async with self.store.pool.acquire() as conn: + stale_count = await conn.fetchval( + "SELECT COUNT(*) FROM notes WHERE connection_count = 0" ) - await self._refresh_all_counts(threshold) + total_count = await conn.fetchval("SELECT COUNT(*) FROM notes") + + if total_count > 0 and stale_count / total_count > 0.5: + logger.info( + f"{stale_count}/{total_count} notes have stale counts, refreshing..." + ) + await self._do_refresh(threshold) except Exception as e: logger.warning(f"Failed to check count freshness: {e}") async def _refresh_all_counts(self, threshold: float): """ - Background task to refresh connection_count for all notes. + Refresh connection_count for all notes (acquires lock). + + Convenience wrapper that acquires _refresh_lock before refreshing. + Prefer _do_refresh() when caller already holds the lock. + """ + async with self._refresh_lock: + await self._do_refresh(threshold) + + async def _do_refresh(self, threshold: float): + """ + Refresh connection_count for all notes (caller must hold _refresh_lock). Uses batched SQL approach instead of O(N²) individual queries. Computes counts in batches to balance memory usage and performance. @@ -185,82 +187,75 @@ async def _refresh_all_counts(self, threshold: float): Args: threshold: Similarity threshold for counting connections - Thread-Safety: - - Acquires self._refresh_lock to ensure exclusive execution - - Blocks until lock available (serializes concurrent refresh requests) - - Lock automatically released after completion or error - Performance: - Processes notes in batches of 100 to avoid memory issues - Each batch uses a single SQL query with vector distance computation - Total complexity: O(N * B) where B = batch size, much better than O(N²) """ - async with self._refresh_lock: # Acquire lock (blocks until available) - logger.info("Starting background connection count refresh (lock acquired)...") + logger.info("Starting connection count refresh...") - try: - distance_threshold = 1.0 - threshold - batch_size = 100 # Process 100 notes at a time + try: + distance_threshold = 1.0 - threshold + batch_size = 100 # Process 100 notes at a time - async with self.store.pool.acquire() as conn: - # Get total count for progress logging - total_notes = await conn.fetchval( - "SELECT COUNT(*) FROM notes WHERE embedding IS NOT NULL" - ) + async with self.store.pool.acquire() as conn: + # Get total count for progress logging + total_notes = await conn.fetchval( + "SELECT COUNT(*) FROM notes WHERE embedding IS NOT NULL" + ) - if total_notes == 0: - logger.info("No notes with embeddings to refresh") - return - - logger.info(f"Refreshing connection counts for {total_notes} notes...") - - # Process in batches using OFFSET/LIMIT - processed = 0 - for offset in range(0, total_notes, batch_size): - # Get batch of note paths - batch_paths = await conn.fetch( - """ - SELECT path FROM notes - WHERE embedding IS NOT NULL - ORDER BY path - LIMIT $1 OFFSET $2 - """, - batch_size, - offset, - ) + if total_notes == 0: + logger.info("No notes with embeddings to refresh") + return + + logger.info(f"Refreshing connection counts for {total_notes} notes...") + + # Process in batches using OFFSET/LIMIT + processed = 0 + for offset in range(0, total_notes, batch_size): + # Get batch of note paths + batch_paths = await conn.fetch( + """ + SELECT path FROM notes + WHERE embedding IS NOT NULL + ORDER BY path + LIMIT $1 OFFSET $2 + """, + batch_size, + offset, + ) - if not batch_paths: - break - - # Update counts for this batch using a single efficient query - # This computes connection counts for all notes in the batch at once - await conn.execute( - """ - UPDATE notes AS n - SET connection_count = subq.cnt, - last_indexed_at = CURRENT_TIMESTAMP - FROM ( - SELECT n1.path, COUNT(n2.path) AS cnt - FROM notes n1 - LEFT JOIN notes n2 ON n1.path != n2.path - AND n2.embedding IS NOT NULL - AND (n1.embedding <=> n2.embedding) <= $1 - WHERE n1.path = ANY($2::text[]) - AND n1.embedding IS NOT NULL - GROUP BY n1.path - ) AS subq - WHERE n.path = subq.path - """, - distance_threshold, - [r["path"] for r in batch_paths], - ) + if not batch_paths: + break + + # Update counts for this batch using a single efficient query + # Paths come from database (already validated on insertion), not user input + await conn.execute( + """ + UPDATE notes AS n + SET connection_count = subq.cnt, + last_indexed_at = CURRENT_TIMESTAMP + FROM ( + SELECT n1.path, COUNT(n2.path) AS cnt + FROM notes n1 + LEFT JOIN notes n2 ON n1.path != n2.path + AND n2.embedding IS NOT NULL + AND (n1.embedding <=> n2.embedding) <= $1 + WHERE n1.path = ANY($2::text[]) + AND n1.embedding IS NOT NULL + GROUP BY n1.path + ) AS subq + WHERE n.path = subq.path + """, + distance_threshold, + [r["path"] for r in batch_paths], + ) - processed += len(batch_paths) - if processed % 500 == 0 or processed == total_notes: - logger.debug(f"Refreshed {processed}/{total_notes} notes") + processed += len(batch_paths) + if processed % 500 == 0 or processed == total_notes: + logger.debug(f"Refreshed {processed}/{total_notes} notes") - logger.success(f"Connection count refresh complete ({total_notes} notes)") + logger.success(f"Connection count refresh complete ({total_notes} notes)") - except Exception as e: - logger.error(f"Connection count refresh failed: {e}") - # Lock automatically released here + except Exception as e: + logger.error(f"Connection count refresh failed: {e}") diff --git a/tests/test_hub_analyzer.py b/tests/test_hub_analyzer.py index 4ac3b08..e9ba3f4 100644 --- a/tests/test_hub_analyzer.py +++ b/tests/test_hub_analyzer.py @@ -177,11 +177,11 @@ async def hold_lock_briefly(): await lock_acquired.wait() # While locked, ensure_fresh_counts should wait (not skip) + # After lock releases, it checks staleness inside the lock await analyzer._ensure_fresh_counts(0.5) - # After waiting for lock release, it should return without querying - # (the wait-and-return path doesn't query the database) - mock_conn.fetchval.assert_not_called() + # After waiting, it should have checked staleness (fetchval called) + assert mock_conn.fetchval.call_count >= 1, "Should check staleness after acquiring lock" await hold_task @@ -247,14 +247,14 @@ async def __aexit__(self, *args): mock_store.pool = MagicMock() mock_store.pool.acquire = MagicMock(return_value=MockAcquire()) - # Mock refresh - analyzer._refresh_all_counts = AsyncMock() + # Mock refresh (now _do_refresh since check+refresh are atomic inside lock) + analyzer._do_refresh = AsyncMock() # Check freshness - refresh is now awaited inline await analyzer._ensure_fresh_counts(0.5) # Should have triggered refresh (60% > 50% threshold) - analyzer._refresh_all_counts.assert_called_once_with(0.5) + analyzer._do_refresh.assert_called_once_with(0.5) @pytest.mark.asyncio diff --git a/tests/test_race_conditions.py b/tests/test_race_conditions.py index 8408f07..435bb0a 100644 --- a/tests/test_race_conditions.py +++ b/tests/test_race_conditions.py @@ -152,17 +152,16 @@ async def test_hub_analyzer_concurrent_refresh_race(): Race condition scenario: - 20 concurrent requests call _ensure_fresh_counts() - - All detect stale counts - - All schedule refresh via asyncio.create_task() - - Only ONE actual vault scan should execute + - First caller acquires lock, checks staleness, runs refresh + - Other callers wait for lock, then re-check staleness (now fresh), skip refresh - Without proper locking, multiple refreshes can run concurrently (expensive!). + The lock serializes both the staleness check and refresh atomically, + preventing duplicate refreshes. """ # Create mock store mock_pool = MagicMock() mock_conn = AsyncMock() - # Mock the connection acquisition class MockAcquire: async def __aenter__(self): return mock_conn @@ -172,46 +171,31 @@ async def __aexit__(self, *args): mock_pool.acquire = MagicMock(return_value=MockAcquire()) - # Mock queries to simulate stale counts (triggers refresh) - # Note: stale_count/total_count must be > 0.5 to trigger refresh - mock_conn.fetchval = AsyncMock( - side_effect=[ - 501, # stale_count (first call) - 501/1000 > 0.5 triggers refresh - 1000, # total_count (second call) - 501, - 1000, - 501, - 1000, - 501, - 1000, # Repeat for concurrent calls - 501, - 1000, - 501, - 1000, - 501, - 1000, - 501, - 1000, - 501, - 1000, - 501, - 1000, - 501, - 1000, - 501, - 1000, - 501, - 1000, - 501, - 1000, - 501, - 1000, - 501, - 1000, - ] - ) + mock_store = MagicMock() + mock_store.pool = mock_pool - # Mock fetch for refresh operation + analyzer = HubAnalyzer(mock_store) + + # Track refresh executions + refresh_count = 0 + + # First caller sees stale counts (501/1000 > 50%), triggers refresh. + # After refresh, subsequent callers see fresh counts (0/1000 = 0%, no refresh needed). + call_count = 0 + + async def mock_fetchval(query): + nonlocal call_count + call_count += 1 + if "connection_count = 0" in query: + # First check returns stale, subsequent checks return fresh + return 501 if call_count <= 1 else 0 + if "COUNT(*) FROM notes" in query: + return 1000 + if "embedding IS NOT NULL" in query: + return 2 + return 0 + + mock_conn.fetchval = mock_fetchval mock_conn.fetch = AsyncMock( return_value=[ {"path": "note1.md", "embedding": [0.1] * 1024}, @@ -220,45 +204,24 @@ async def __aexit__(self, *args): ) mock_conn.execute = AsyncMock() - mock_store = MagicMock() - mock_store.pool = mock_pool - - # Create hub analyzer - analyzer = HubAnalyzer(mock_store) - - # Track number of actual refresh executions - refresh_count = 0 - refresh_lock = asyncio.Lock() - - original_refresh = analyzer._refresh_all_counts + original_do_refresh = analyzer._do_refresh async def tracked_refresh(threshold): nonlocal refresh_count - async with refresh_lock: - refresh_count += 1 - # Call original to test actual logic - await original_refresh(threshold) + refresh_count += 1 + await original_do_refresh(threshold) - analyzer._refresh_all_counts = tracked_refresh + analyzer._do_refresh = tracked_refresh - # Simulate 20 concurrent requests triggering refresh + # Simulate 20 concurrent requests tasks = [analyzer._ensure_fresh_counts(0.5) for _ in range(20)] - await asyncio.gather(*tasks) - # Allow background tasks to complete - await asyncio.sleep(0.5) - - # The design allows multiple refreshes to be scheduled (all concurrent calls see stale counts). - # The lock ensures they execute SERIALLY, not concurrently. - # This test verifies the lock is working: with 20 concurrent calls, - # some will schedule refreshes, but they'll run one at a time. - # We accept that multiple refreshes may run, as long as they don't run concurrently. - # Note: This is a design trade-off - preventing scheduling entirely would require - # a more complex "refresh scheduled" flag with atomic compare-and-swap semantics. - assert refresh_count <= 20, ( - f"More refreshes than requests: got {refresh_count} refreshes for 20 requests. " - "This suggests a bug in the test or logic." + # With atomic check+refresh inside lock, only 1 refresh should run. + # Subsequent callers re-check staleness after acquiring lock and find counts are fresh. + assert refresh_count == 1, ( + f"Expected exactly 1 refresh, got {refresh_count}. " + "The lock should serialize check+refresh atomically." ) From c78451af493f53863c2d856b801dc4a70123f704 Mon Sep 17 00:00:00 2001 From: drewburchfield Date: Fri, 3 Apr 2026 12:57:48 -0500 Subject: [PATCH 3/3] Address Devin findings: fix test mock target and pool connection leak - Fix test_staleness_check_skips_refresh_when_fresh to mock _do_refresh instead of _refresh_all_counts (vacuously true assertion) - Release pool connection before calling _do_refresh to avoid holding an idle connection during the entire refresh operation --- src/hub_analyzer.py | 11 ++++++----- tests/test_hub_analyzer.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/hub_analyzer.py b/src/hub_analyzer.py index 271a3f1..4f94b95 100644 --- a/src/hub_analyzer.py +++ b/src/hub_analyzer.py @@ -158,11 +158,12 @@ async def _ensure_fresh_counts(self, threshold: float): ) total_count = await conn.fetchval("SELECT COUNT(*) FROM notes") - if total_count > 0 and stale_count / total_count > 0.5: - logger.info( - f"{stale_count}/{total_count} notes have stale counts, refreshing..." - ) - await self._do_refresh(threshold) + # Release pool connection before potentially long refresh + if total_count > 0 and stale_count / total_count > 0.5: + logger.info( + f"{stale_count}/{total_count} notes have stale counts, refreshing..." + ) + await self._do_refresh(threshold) except Exception as e: logger.warning(f"Failed to check count freshness: {e}") diff --git a/tests/test_hub_analyzer.py b/tests/test_hub_analyzer.py index e9ba3f4..3c6e9b2 100644 --- a/tests/test_hub_analyzer.py +++ b/tests/test_hub_analyzer.py @@ -277,10 +277,10 @@ async def __aexit__(self, *args): mock_store.pool.acquire = MagicMock(return_value=MockAcquire()) # Mock refresh - analyzer._refresh_all_counts = AsyncMock() + analyzer._do_refresh = AsyncMock() # Check freshness await analyzer._ensure_fresh_counts(0.5) # Should NOT have triggered refresh (20% < 50% threshold) - analyzer._refresh_all_counts.assert_not_called() + analyzer._do_refresh.assert_not_called()