Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 240 additions & 28 deletions grafi/workflows/impl/async_node_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,53 +4,265 @@
import asyncio
from collections import defaultdict
from typing import Dict
from typing import Optional
from typing import Set

from loguru import logger


class AsyncNodeTracker:
    """
    Central tracker for workflow activity and quiescence detection.

    Design: All tracking calls come from the ORCHESTRATOR layer,
    not from TopicBase. This keeps topics as pure message queues.

    Quiescence = (no active nodes) AND (no uncommitted messages) AND (work done)

    Usage in workflow:
        # In publish_events():
        await tracker.on_messages_published(len(published_events))

        # In _commit_events():
        await tracker.on_messages_committed(len(events))

        # In node processing:
        await tracker.enter(node_name)
        ... process ...
        await tracker.leave(node_name)
    """

    def __init__(self) -> None:
        # Node activity tracking
        self._active: Set[str] = set()
        self._processing_count: Dict[str, int] = defaultdict(int)

        # Message tracking (uncommitted = published but not yet committed)
        self._uncommitted_messages: int = 0

        # Synchronization: the condition's lock guards every counter below;
        # the event is what waiters block on.
        self._cond = asyncio.Condition()
        self._quiescence_event = asyncio.Event()

        # Work tracking (prevents premature quiescence before any work)
        self._total_committed: int = 0
        self._has_started: bool = False

        # Force stop flag (for explicit workflow stop)
        self._force_stopped: bool = False

    def reset(self) -> None:
        """Reset for a new workflow run.

        The condition and the quiescence event are replaced with fresh
        objects, so anything still awaiting the previous ones will not be
        woken by activity recorded after this call.
        """
        self._active.clear()
        self._processing_count.clear()
        self._uncommitted_messages = 0
        self._cond = asyncio.Condition()
        self._quiescence_event = asyncio.Event()
        self._total_committed = 0
        self._has_started = False
        self._force_stopped = False

    # ─────────────────────────────────────────────────────────────────────────
    # Node Lifecycle (called from _invoke_node)
    # ─────────────────────────────────────────────────────────────────────────

    async def enter(self, node_name: str) -> None:
        """Called when a node begins processing.

        Marks the workflow as started, clears the quiescence signal and
        increments the per-node processing counter.
        """
        async with self._cond:
            self._has_started = True
            self._quiescence_event.clear()
            self._active.add(node_name)
            self._processing_count[node_name] += 1

    async def leave(self, node_name: str) -> None:
        """Called when a node finishes processing.

        Removes the node from the active set (a no-op if it was not there),
        re-evaluates quiescence and wakes every waiter on the condition.
        """
        async with self._cond:
            self._active.discard(node_name)
            self._check_quiescence_unlocked()
            self._cond.notify_all()

    # ─────────────────────────────────────────────────────────────────────────
    # Message Tracking (called from orchestrator utilities)
    # ─────────────────────────────────────────────────────────────────────────

    async def on_messages_published(self, count: int = 1, source: str = "") -> None:
        """
        Called when messages are published to topics.

        Call site: publish_events() in utils.py

        Args:
            count: Number of messages published; values <= 0 are ignored.
            source: Label used only in the debug log line.
        """
        if count <= 0:
            return
        async with self._cond:
            self._has_started = True
            self._quiescence_event.clear()
            self._uncommitted_messages += count

            logger.debug(
                f"Tracker: {count} messages published from {source} (uncommitted={self._uncommitted_messages})"
            )

    async def on_messages_committed(self, count: int = 1, source: str = "") -> None:
        """
        Called when messages are committed (consumed and acknowledged).

        Call site: _commit_events() in EventDrivenWorkflow

        Args:
            count: Number of messages committed; values <= 0 are ignored.
            source: Label used only in the debug log line.
        """
        if count <= 0:
            return
        async with self._cond:
            # Clamp at zero so committing more than was published cannot
            # drive the counter negative.
            self._uncommitted_messages = max(0, self._uncommitted_messages - count)
            self._total_committed += count
            self._check_quiescence_unlocked()

            logger.debug(
                f"Tracker: {count} messages committed from {source} "
                f"(uncommitted={self._uncommitted_messages}, total={self._total_committed})"
            )
            self._cond.notify_all()

    # Aliases for clarity
    async def on_message_published(self) -> None:
        """Single message version."""
        await self.on_messages_published(1)

    async def on_message_committed(self) -> None:
        """Single message version."""
        await self.on_messages_committed(1)

    # ─────────────────────────────────────────────────────────────────────────
    # Quiescence Detection
    # ─────────────────────────────────────────────────────────────────────────

    def _check_quiescence_unlocked(self) -> None:
        """
        Check and signal quiescence if all conditions met.

        MUST be called with self._cond lock held.
        """
        is_quiescent = self._is_quiescent_unlocked()
        logger.debug(
            f"Tracker: checking quiescence - active={list(self._active)}, "
            f"uncommitted={self._uncommitted_messages}, "
            f"has_started={self._has_started}, "
            f"total_committed={self._total_committed}, "
            f"is_quiescent={is_quiescent}"
        )
        if is_quiescent:
            logger.info(
                f"Tracker: quiescence detected (committed={self._total_committed})"
            )
            self._quiescence_event.set()

    def _is_quiescent_unlocked(self) -> bool:
        """
        Internal quiescence check without lock.

        MUST be called with self._cond lock held.

        True when workflow is truly idle:
        - No nodes actively processing
        - No messages waiting to be committed
        - At least some work was done
        """
        return (
            not self._active
            and self._uncommitted_messages == 0
            and self._has_started
            and self._total_committed > 0
        )

    async def is_quiescent(self) -> bool:
        """
        True when workflow is truly idle:
        - No nodes actively processing
        - No messages waiting to be committed
        - At least some work was done

        This method acquires the lock to ensure consistent reads.
        """
        async with self._cond:
            return self._is_quiescent_unlocked()

    def _should_terminate_unlocked(self) -> bool:
        """
        Internal termination check without lock.

        MUST be called with self._cond lock held.
        """
        return self._is_quiescent_unlocked() or self._force_stopped

    async def should_terminate(self) -> bool:
        """
        True when workflow should stop iteration.
        Either natural quiescence or explicit force stop.

        This method acquires the lock to ensure consistent reads.
        """
        async with self._cond:
            return self._should_terminate_unlocked()

    async def force_stop(self) -> None:
        """
        Force the workflow to stop immediately (async version with lock).
        Called when workflow.stop() is invoked from async context.
        """
        async with self._cond:
            logger.info("Tracker: force stop requested")
            self._force_stopped = True
            self._quiescence_event.set()
            self._cond.notify_all()

    def force_stop_sync(self) -> None:
        """
        Force the workflow to stop immediately (sync version).

        Sets the stop flag and the quiescence event without acquiring the
        async lock, for callers that cannot await (e.g. a sync stop() method).

        Skipping the lock is acceptable only because both writes are plain
        assignments/sets that cannot be interleaved with another coroutine
        while running on the event loop thread. asyncio primitives are NOT
        thread-safe: when stopping from a different thread, schedule this
        method with ``loop.call_soon_threadsafe(tracker.force_stop_sync)``
        instead of calling it directly.

        Waiters on ``self._cond`` are not notified here (that requires the
        lock); they observe the flag on their next lock acquisition.
        """
        logger.info("Tracker: force stop requested (sync)")
        self._force_stopped = True
        self._quiescence_event.set()

    async def is_idle(self) -> bool:
        """Legacy: just checks if no active nodes."""
        async with self._cond:
            return not self._active

    async def wait_for_quiescence(self, timeout: Optional[float] = None) -> bool:
        """Wait until quiescent. Returns False on timeout.

        Args:
            timeout: Maximum seconds to wait. ``None`` waits indefinitely;
                ``0`` performs a non-blocking check.
        """
        # Already signalled: answer immediately regardless of the timeout.
        if self._quiescence_event.is_set():
            return True
        # Compare against None explicitly: a timeout of 0 is a valid request
        # for an immediate answer and must not fall into the unbounded wait.
        if timeout is None:
            await self._quiescence_event.wait()
            return True
        try:
            await asyncio.wait_for(self._quiescence_event.wait(), timeout)
        except asyncio.TimeoutError:
            return False
        return True

    async def wait_idle_event(self) -> None:
        """Legacy compatibility."""
        await self._quiescence_event.wait()

    # ─────────────────────────────────────────────────────────────────────────
    # Metrics
    # ─────────────────────────────────────────────────────────────────────────

    async def get_activity_count(self) -> int:
        """Total processing count across all nodes."""
        async with self._cond:
            return sum(self._processing_count.values())

    async def get_metrics(self) -> Dict:
        """Detailed metrics for debugging."""
        async with self._cond:
            return {
                "active_nodes": list(self._active),
                "uncommitted_messages": self._uncommitted_messages,
                "total_committed": self._total_committed,
                "is_quiescent": self._is_quiescent_unlocked(),
            }
Loading
Loading