Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions docs/CONCURRENCY.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class ObsidianFileWatcher:
**Thread-Safety Guarantees:**
- ✅ Concurrent calls to `get_hub_notes()`: Safe (read-only queries)
- ✅ Concurrent calls to `get_orphaned_notes()`: Safe (read-only queries)
- ✅ Multiple refresh requests: Only ONE refresh runs (others skip)
- ✅ Multiple refresh requests: Only ONE refresh runs (others wait for it)

**Synchronization Mechanism**: Per-instance `asyncio.Lock` (`self._refresh_lock`) that serializes refresh operations

Expand All @@ -71,22 +71,16 @@ class HubAnalyzer:
| Scenario | Behavior | Performance |
|----------|----------|-------------|
| 20 concurrent `get_hub_notes()` calls | All execute in parallel (read-only) | Optimal |
| 20 concurrent requests trigger refresh | Only 1 refresh runs, others skip | Optimal |
| Refresh running, new `get_hub_notes()` call | Query executes immediately (no blocking) | Optimal |
| 20 concurrent requests trigger refresh | 1 refresh runs, others wait then re-check | Correct |
| Refresh running, new `get_hub_notes()` call | Waits for refresh, then queries fresh data | Correct |

**Refresh Logic:**
```python
async def _ensure_fresh_counts(self, threshold):
if self._refresh_lock.locked(): # Non-blocking check
return # Skip if refresh already running

# Check staleness, schedule refresh if needed
asyncio.create_task(self._refresh_all_counts(threshold))

async def _refresh_all_counts(self, threshold):
async with self._refresh_lock: # Blocking - waits for lock
# Scan entire vault, update connection_count for all notes
...
async with self._refresh_lock: # Staleness check + refresh are atomic
# Check staleness inside lock to prevent TOCTOU races
if stale_count / total_count > 0.5:
await self._do_refresh(threshold) # Inline, not fire-and-forget
```

**Lock Granularity**: Coarse-grained (one lock for entire refresh operation)
Expand Down Expand Up @@ -175,7 +169,7 @@ pytest tests/test_race_conditions.py::test_hub_analyzer_concurrent_refresh_race
|------|------------------|-------------------|
| File watcher same file | 10 rapid edits | 1 re-index |
| File watcher different files | 50 concurrent files | 50 parallel re-indexes |
| Hub analyzer refresh | 20 concurrent requests | 1 refresh execution |
| Hub analyzer refresh | 20 concurrent requests | 1 refresh, others wait |

---

Expand Down
201 changes: 90 additions & 111 deletions src/hub_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ class HubAnalyzer:
Thread-Safety:
- Uses asyncio.Lock for refresh operations (serializes coroutines on one event loop; asyncio.Lock is not safe across OS threads)
- Multiple concurrent calls to get_hub_notes/get_orphaned_notes: safe
- Only ONE vault refresh runs at a time (others skip if in progress)
- Only ONE vault refresh runs at a time (others wait for completion)

Performance:
- Refresh is O(N²) where N = number of notes
- Triggered when >50% of notes have stale connection_count
- Runs in background (non-blocking for queries)
- Awaited inline so counts are fresh before queries return
"""

def __init__(self, store: PostgreSQLVectorStore):
Expand All @@ -39,22 +39,6 @@ def __init__(self, store: PostgreSQLVectorStore):
self.store = store
self._refresh_lock = asyncio.Lock() # Replaces refresh_in_progress boolean

def _handle_refresh_error(self, task: asyncio.Task):
"""
Error callback for background refresh tasks.

Logs errors from background connection count refresh without crashing.
Non-fatal - hub/orphan queries will still work with stale counts.

Args:
task: Completed asyncio Task to check for errors
"""
try:
task.result() # Raises exception if task failed
except Exception as e:
logger.error(f"Background connection count refresh failed: {e}", exc_info=True)
# Non-fatal - queries will work with stale counts

async def get_hub_notes(
self, min_connections: int = 10, threshold: float = 0.5, limit: int = 20
) -> list[dict]:
Expand Down Expand Up @@ -156,128 +140,123 @@ async def get_orphaned_notes(

async def _ensure_fresh_counts(self, threshold: float):
"""
Ensure connection counts are fresh (or trigger background refresh).
Ensure connection counts are fresh before querying.

Checks if any notes have stale connection_count (last_indexed_at old).
Triggers background refresh if needed.
Checks staleness and refreshes inline so counts are ready when callers query.
All logic runs inside the lock to prevent TOCTOU races and duplicate refreshes.

Thread-Safety:
- Uses non-blocking lock check to avoid queueing multiple refreshes
- If refresh already running, skips check (other request will handle it)
- Staleness check and refresh are atomic (both inside _refresh_lock)
- Concurrent callers wait for the lock, then re-check staleness
- No duplicate refreshes possible
"""
# Non-blocking lock check - if refresh already running, skip
if self._refresh_lock.locked():
logger.debug("Refresh already in progress, skipping")
return

try:
async with self.store.pool.acquire() as conn:
# Check how many notes have connection_count = 0 (likely stale)
stale_count = await conn.fetchval(
"SELECT COUNT(*) FROM notes WHERE connection_count = 0"
)

# If >50% of notes have count=0, trigger refresh
total_count = await conn.fetchval("SELECT COUNT(*) FROM notes")
async with self._refresh_lock:
async with self.store.pool.acquire() as conn:
stale_count = await conn.fetchval(
"SELECT COUNT(*) FROM notes WHERE connection_count = 0"
)
total_count = await conn.fetchval("SELECT COUNT(*) FROM notes")

# Release pool connection before potentially long refresh
if total_count > 0 and stale_count / total_count > 0.5:
logger.info(
f"{stale_count}/{total_count} notes have stale counts, refreshing..."
)
# Trigger background refresh with error handling
task = asyncio.create_task(self._refresh_all_counts(threshold))
task.add_done_callback(self._handle_refresh_error)
logger.debug("Scheduled background refresh task")
await self._do_refresh(threshold)

except Exception as e:
logger.warning(f"Failed to check count freshness: {e}")

async def _refresh_all_counts(self, threshold: float):
"""
Background task to refresh connection_count for all notes.
Refresh connection_count for all notes (acquires lock).

Convenience wrapper that acquires _refresh_lock before refreshing.
Prefer _do_refresh() when caller already holds the lock.
"""
async with self._refresh_lock:
await self._do_refresh(threshold)

async def _do_refresh(self, threshold: float):
"""
Refresh connection_count for all notes (caller must hold _refresh_lock).

Uses batched SQL approach instead of O(N²) individual queries.
Computes counts in batches to balance memory usage and performance.

Args:
threshold: Similarity threshold for counting connections

Thread-Safety:
- Acquires self._refresh_lock to ensure exclusive execution
- Blocks until lock available (serializes concurrent refresh requests)
- Lock automatically released after completion or error

Performance:
- Processes notes in batches of 100 to avoid memory issues
- Each batch uses a single SQL query with vector distance computation
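- Runs about N/B batched UPDATE statements (B = batch size); each batch still compares its notes against every other embedded note, so the total number of distance computations remains O(N²)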
"""
async with self._refresh_lock: # Acquire lock (blocks until available)
logger.info("Starting background connection count refresh (lock acquired)...")
logger.info("Starting connection count refresh...")

try:
distance_threshold = 1.0 - threshold
batch_size = 100 # Process 100 notes at a time
try:
distance_threshold = 1.0 - threshold
batch_size = 100 # Process 100 notes at a time

async with self.store.pool.acquire() as conn:
# Get total count for progress logging
total_notes = await conn.fetchval(
"SELECT COUNT(*) FROM notes WHERE embedding IS NOT NULL"
async with self.store.pool.acquire() as conn:
# Get total count for progress logging
total_notes = await conn.fetchval(
"SELECT COUNT(*) FROM notes WHERE embedding IS NOT NULL"
)

if total_notes == 0:
logger.info("No notes with embeddings to refresh")
return

logger.info(f"Refreshing connection counts for {total_notes} notes...")

# Process in batches using OFFSET/LIMIT
processed = 0
for offset in range(0, total_notes, batch_size):
# Get batch of note paths
batch_paths = await conn.fetch(
"""
SELECT path FROM notes
WHERE embedding IS NOT NULL
ORDER BY path
LIMIT $1 OFFSET $2
""",
batch_size,
offset,
)

if not batch_paths:
break

# Update counts for this batch using a single efficient query
# Paths come from database (already validated on insertion), not user input
await conn.execute(
"""
UPDATE notes AS n
SET connection_count = subq.cnt,
last_indexed_at = CURRENT_TIMESTAMP
FROM (
SELECT n1.path, COUNT(n2.path) AS cnt
FROM notes n1
LEFT JOIN notes n2 ON n1.path != n2.path
AND n2.embedding IS NOT NULL
AND (n1.embedding <=> n2.embedding) <= $1
WHERE n1.path = ANY($2::text[])
AND n1.embedding IS NOT NULL
GROUP BY n1.path
) AS subq
WHERE n.path = subq.path
""",
distance_threshold,
[r["path"] for r in batch_paths],
)

if total_notes == 0:
logger.info("No notes with embeddings to refresh")
return

logger.info(f"Refreshing connection counts for {total_notes} notes...")

# Process in batches using OFFSET/LIMIT
processed = 0
for offset in range(0, total_notes, batch_size):
# Get batch of note paths
batch_paths = await conn.fetch(
"""
SELECT path FROM notes
WHERE embedding IS NOT NULL
ORDER BY path
LIMIT $1 OFFSET $2
""",
batch_size,
offset,
)

if not batch_paths:
break

# Update counts for this batch using a single efficient query
# This computes connection counts for all notes in the batch at once
await conn.execute(
"""
UPDATE notes AS n
SET connection_count = subq.cnt,
last_indexed_at = CURRENT_TIMESTAMP
FROM (
SELECT n1.path, COUNT(n2.path) AS cnt
FROM notes n1
LEFT JOIN notes n2 ON n1.path != n2.path
AND n2.embedding IS NOT NULL
AND (n1.embedding <=> n2.embedding) <= $1
WHERE n1.path = ANY($2::text[])
AND n1.embedding IS NOT NULL
GROUP BY n1.path
) AS subq
WHERE n.path = subq.path
""",
distance_threshold,
[r["path"] for r in batch_paths],
)

processed += len(batch_paths)
if processed % 500 == 0 or processed == total_notes:
logger.debug(f"Refreshed {processed}/{total_notes} notes")

logger.success(f"Connection count refresh complete ({total_notes} notes)")

except Exception as e:
logger.error(f"Connection count refresh failed: {e}")
# Lock automatically released here
processed += len(batch_paths)
if processed % 500 == 0 or processed == total_notes:
logger.debug(f"Refreshed {processed}/{total_notes} notes")

logger.success(f"Connection count refresh complete ({total_notes} notes)")

except Exception as e:
logger.error(f"Connection count refresh failed: {e}")
45 changes: 27 additions & 18 deletions tests/test_hub_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ async def tracked_refresh(threshold):


@pytest.mark.asyncio
async def test_ensure_fresh_counts_skips_if_locked(mock_store):
"""Test that _ensure_fresh_counts skips check if refresh already running."""
async def test_ensure_fresh_counts_waits_if_locked(mock_store):
"""Test that _ensure_fresh_counts waits for lock when refresh already running."""
analyzer = HubAnalyzer(mock_store)

# Mock database
Expand All @@ -164,13 +164,26 @@ async def __aexit__(self, *args):
mock_store.pool = MagicMock()
mock_store.pool.acquire = MagicMock(return_value=MockAcquire())

# Acquire lock to simulate refresh in progress
async with analyzer._refresh_lock:
# While locked, ensure_fresh_counts should skip immediately
await analyzer._ensure_fresh_counts(0.5)
# Acquire lock in a background task to simulate refresh in progress,
# then release it after a short delay so _ensure_fresh_counts can proceed
lock_acquired = asyncio.Event()

# Should not have queried database (skipped due to lock)
mock_conn.fetchval.assert_not_called()
async def hold_lock_briefly():
async with analyzer._refresh_lock:
lock_acquired.set()
await asyncio.sleep(0.05) # Hold lock briefly

hold_task = asyncio.create_task(hold_lock_briefly())
await lock_acquired.wait()

# While locked, ensure_fresh_counts should wait (not skip)
# After lock releases, it checks staleness inside the lock
await analyzer._ensure_fresh_counts(0.5)

# After waiting, it should have checked staleness (fetchval called)
assert mock_conn.fetchval.call_count >= 1, "Should check staleness after acquiring lock"

await hold_task


@pytest.mark.asyncio
Expand Down Expand Up @@ -234,18 +247,14 @@ async def __aexit__(self, *args):
mock_store.pool = MagicMock()
mock_store.pool.acquire = MagicMock(return_value=MockAcquire())

# Mock refresh
analyzer._refresh_all_counts = AsyncMock()
# Mock refresh (now _do_refresh since check+refresh are atomic inside lock)
analyzer._do_refresh = AsyncMock()

# Check freshness
# Check freshness - refresh is now awaited inline
await analyzer._ensure_fresh_counts(0.5)

# Allow background task to be scheduled
await asyncio.sleep(0.1)

# Should have triggered refresh (60% > 50% threshold)
# Note: refresh runs via asyncio.create_task, so we check it was called
# The actual scheduling depends on event loop timing
analyzer._do_refresh.assert_called_once_with(0.5)


@pytest.mark.asyncio
Expand All @@ -268,10 +277,10 @@ async def __aexit__(self, *args):
mock_store.pool.acquire = MagicMock(return_value=MockAcquire())

# Mock refresh
analyzer._refresh_all_counts = AsyncMock()
analyzer._do_refresh = AsyncMock()

# Check freshness
await analyzer._ensure_fresh_counts(0.5)

# Should NOT have triggered refresh (20% < 50% threshold)
analyzer._refresh_all_counts.assert_not_called()
analyzer._do_refresh.assert_not_called()
Loading
Loading