Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 34 additions & 20 deletions grafi/common/containers/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,38 +29,52 @@ def __init__(self) -> None:
# Per-instance attributes:
self._event_store: Optional[EventStore] = None
self._tracer: Optional[Tracer] = None
# Lock for thread-safe lazy initialization of properties
self._init_lock: threading.Lock = threading.Lock()

def register_event_store(self, event_store: EventStore) -> None:
"""Override the default EventStore implementation."""
if isinstance(event_store, EventStoreInMemory):
logger.warning(
"Using EventStoreInMemory. This is ONLY suitable for local testing but not for production."
)
self._event_store = event_store
with self._init_lock:
if isinstance(event_store, EventStoreInMemory):
logger.warning(
"Using EventStoreInMemory. This is ONLY suitable for local testing but not for production."
)
self._event_store = event_store

def register_tracer(self, tracer: Tracer) -> None:
"""Override the default Tracer implementation."""
self._tracer = tracer
with self._init_lock:
self._tracer = tracer

@property
def event_store(self) -> EventStore:
if self._event_store is None:
logger.warning(
"Using EventStoreInMemory. This is ONLY suitable for local testing but not for production."
)
self._event_store = EventStoreInMemory()
return self._event_store
# Fast path: already initialized
if self._event_store is not None:
return self._event_store
# Slow path: initialize with lock (double-checked locking)
with self._init_lock:
if self._event_store is None:
logger.warning(
"Using EventStoreInMemory. This is ONLY suitable for local testing but not for production."
)
self._event_store = EventStoreInMemory()
return self._event_store

@property
def tracer(self) -> Tracer:
if self._tracer is None:
self._tracer = setup_tracing(
tracing_options=TracingOptions.AUTO,
collector_endpoint="localhost",
collector_port=4317,
project_name="grafi-trace",
)
return self._tracer
# Fast path: already initialized
if self._tracer is not None:
return self._tracer
# Slow path: initialize with lock (double-checked locking)
with self._init_lock:
if self._tracer is None:
self._tracer = setup_tracing(
tracing_options=TracingOptions.AUTO,
collector_endpoint="localhost",
collector_port=4317,
project_name="grafi-trace",
)
return self._tracer


container: Container = Container()
15 changes: 13 additions & 2 deletions grafi/common/models/async_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ def __init__(self, source: AsyncGenerator[ConsumeFromTopicEvent, None]):
self._done = asyncio.Event()
self._started = False
self._exc: Optional[BaseException] = None
self._producer_task: Optional[asyncio.Task] = None

def _ensure_started(self) -> None:
if not self._started:
loop = asyncio.get_running_loop()
loop.create_task(self._producer())
self._producer_task = loop.create_task(self._producer())
self._started = True

async def _producer(self) -> None:
Expand Down Expand Up @@ -94,10 +95,20 @@ async def to_list(self) -> list[ConsumeFromTopicEvent]:
return result if isinstance(result, list) else [result]

async def aclose(self) -> None:
"""Attempt to close the underlying async generator (if any)."""
"""Cancel producer task and close the underlying async generator."""
# Cancel the producer task if it's running
if self._producer_task is not None and not self._producer_task.done():
self._producer_task.cancel()
try:
await self._producer_task
except asyncio.CancelledError:
# The task was cancelled by aclose(); a CancelledError here is expected.
pass
# Close the underlying source generator
try:
await self._source.aclose()
except Exception:
# Best-effort cleanup: ignore errors from closing the underlying source.
pass


Expand Down
42 changes: 21 additions & 21 deletions grafi/tools/llms/impl/claude_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,29 +102,29 @@ async def invoke(
input_data: Messages,
) -> MsgsAGen:
messages, tools = self.prepare_api_input(input_data)
client = AsyncAnthropic(api_key=self.api_key)

try:
if self.is_streaming:
async with client.messages.stream(
max_tokens=self.max_tokens,
model=self.model,
messages=messages,
tools=tools,
**self.chat_params,
) as stream:
async for event in stream:
if event.type == "text":
yield self.to_stream_messages(event.text)
else:
resp: AnthropicMessage = await client.messages.create(
max_tokens=self.max_tokens,
model=self.model,
messages=messages,
tools=tools,
**self.chat_params,
)
yield self.to_messages(resp)
async with AsyncAnthropic(api_key=self.api_key) as client:
if self.is_streaming:
async with client.messages.stream(
max_tokens=self.max_tokens,
model=self.model,
messages=messages,
tools=tools,
**self.chat_params,
) as stream:
async for event in stream:
if event.type == "text":
yield self.to_stream_messages(event.text)
else:
resp: AnthropicMessage = await client.messages.create(
max_tokens=self.max_tokens,
model=self.model,
messages=messages,
tools=tools,
**self.chat_params,
)
yield self.to_messages(resp)

except asyncio.CancelledError:
raise
Expand Down
2 changes: 1 addition & 1 deletion grafi/tools/llms/impl/gemini_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class GeminiTool(LLM):
name: str = Field(default="GeminiTool")
type: str = Field(default="GeminiTool")
api_key: Optional[str] = Field(default_factory=lambda: os.getenv("GEMINI_API_KEY"))
model: str = Field(default="gemini-2.0-flash-lite")
model: str = Field(default="gemini-2.5-flash-lite")

@classmethod
def builder(cls) -> "GeminiToolBuilder":
Expand Down
49 changes: 24 additions & 25 deletions grafi/tools/llms/impl/openai_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,31 +107,30 @@ async def invoke(
) -> MsgsAGen:
api_messages, api_tools = self.prepare_api_input(input_data)
try:
client = AsyncClient(api_key=self.api_key)

if self.is_streaming:
async for chunk in await client.chat.completions.create(
model=self.model,
messages=api_messages,
tools=api_tools,
stream=True,
**self.chat_params,
):
yield self.to_stream_messages(chunk)
else:
req_func = (
client.chat.completions.create
if not self.structured_output
else client.beta.chat.completions.parse
)
response: ChatCompletion = await req_func(
model=self.model,
messages=api_messages,
tools=api_tools,
**self.chat_params,
)

yield self.to_messages(response)
async with AsyncClient(api_key=self.api_key) as client:
if self.is_streaming:
async for chunk in await client.chat.completions.create(
model=self.model,
messages=api_messages,
tools=api_tools,
stream=True,
**self.chat_params,
):
yield self.to_stream_messages(chunk)
else:
req_func = (
client.chat.completions.create
if not self.structured_output
else client.beta.chat.completions.parse
)
response: ChatCompletion = await req_func(
model=self.model,
messages=api_messages,
tools=api_tools,
**self.chat_params,
)

yield self.to_messages(response)
except asyncio.CancelledError:
raise # let caller handle
except OpenAIError as exc:
Expand Down
13 changes: 11 additions & 2 deletions grafi/topics/queue_impl/in_mem_topic_event_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def fetch(

async with self._cond:
# If timeout is 0 or None and no data, return immediately
while not await self.can_consume(consumer_id):
while not self._can_consume_unlocked(consumer_id):
try:
logger.debug(
f"Consumer {consumer_id} waiting for new messages with timeout={timeout}"
Expand Down Expand Up @@ -109,8 +109,17 @@ async def reset(self) -> None:
self._consumed = defaultdict(int)
self._committed = defaultdict(lambda: -1)

def _can_consume_unlocked(self, consumer_id: str) -> bool:
    """
    Lock-free availability check for a single consumer.

    Compares the value stored for ``consumer_id`` in ``self._consumed`` with
    the current length of ``self._records``. The caller MUST already hold
    ``self._cond``; this helper does not acquire it.
    """
    consumed_so_far = self._consumed[consumer_id]
    record_count = len(self._records)
    return consumed_so_far < record_count

async def can_consume(self, consumer_id: str) -> bool:
"""
Check if there are events available for consumption by a consumer asynchronously.

This method acquires the lock to ensure consistent reads of shared state.
"""
return self._consumed[consumer_id] < len(self._records)
async with self._cond:
return self._can_consume_unlocked(consumer_id)
13 changes: 12 additions & 1 deletion grafi/topics/topic_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,18 @@ async def publish_data(
"""
Publish data to the topic if it meets the condition.
"""
if self.condition(publish_event):
try:
condition_met = self.condition(publish_event)
except Exception as e:
# Condition evaluation failed (e.g., IndexError on empty data)
# Treat as condition not met
logger.debug(
f"[{self.name}] Condition evaluation failed: {e}. "
"Treating as condition not met."
)
condition_met = False

if condition_met:
event = publish_event.model_copy(
update={
"name": self.name,
Expand Down
62 changes: 37 additions & 25 deletions grafi/workflows/impl/async_node_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,32 +44,55 @@ def __init__(self) -> None:
self._cond = asyncio.Condition()
self._quiescence_event = asyncio.Event()

# Work tracking (prevents premature quiescence before any work)
self._total_committed: int = 0
self._has_started: bool = False

# Force stop flag (for explicit workflow stop)
self._force_stopped: bool = False

def reset(self) -> None:
"""Reset for a new workflow run."""
"""
Reset for a new workflow run.

Note: This is a sync reset that replaces primitives. It should only be
called when no coroutines are waiting on the old primitives (e.g., at
the start of a new workflow invocation before any tasks are spawned).
"""
self._active.clear()
self._processing_count.clear()
self._uncommitted_messages = 0
self._cond = asyncio.Condition()
self._quiescence_event = asyncio.Event()
self._total_committed = 0
self._has_started = False
self._force_stopped = False

async def reset_async(self) -> None:
    """
    Reset for a new workflow run (async version).

    Works in two phases, each performed while holding ``self._cond``:

    1. Set the force-stop flag, set the quiescence event, and notify every
       coroutine waiting on the condition, so waiters get a chance to wake
       up and exit.
    2. After yielding once to the event loop, clear the tracked node and
       message state, then clear the force-stop flag and the quiescence
       event again.

    The existing condition and event objects are reused; only their state
    is changed.
    """
    async with self._cond:
        # Wake all waiters so they can exit gracefully
        self._force_stopped = True
        self._quiescence_event.set()
        self._cond.notify_all()
    # Give waiters a chance to wake up and exit (a single yield to the loop).
    # NOTE(review): one yield may not be enough for every waiter to finish
    # before the state below is cleared -- confirm against the wait paths.
    await asyncio.sleep(0)
    # Second phase: clear the tracked state and both stop signals.
    async with self._cond:
        self._active.clear()
        self._processing_count.clear()
        self._uncommitted_messages = 0
        self._force_stopped = False
        self._quiescence_event.clear()

# ─────────────────────────────────────────────────────────────────────────
# Node Lifecycle (called from _invoke_node)
# ─────────────────────────────────────────────────────────────────────────

async def enter(self, node_name: str) -> None:
"""Called when a node begins processing."""
async with self._cond:
self._has_started = True
self._quiescence_event.clear()
self._active.add(node_name)
self._processing_count[node_name] += 1
Expand All @@ -94,7 +117,6 @@ async def on_messages_published(self, count: int = 1, source: str = "") -> None:
if count <= 0:
return
async with self._cond:
self._has_started = True
self._quiescence_event.clear()
self._uncommitted_messages += count

Expand All @@ -112,13 +134,9 @@ async def on_messages_committed(self, count: int = 1, source: str = "") -> None:
return
async with self._cond:
self._uncommitted_messages = max(0, self._uncommitted_messages - count)
self._total_committed += count
self._check_quiescence_unlocked()

logger.debug(
f"Tracker: {count} messages committed from {source} "
f"(uncommitted={self._uncommitted_messages}, total={self._total_committed})"
)
logger.debug(f"Tracker: {count} messages committed from {source}")
self._cond.notify_all()

# Aliases for clarity
Expand All @@ -144,14 +162,9 @@ def _check_quiescence_unlocked(self) -> None:
logger.debug(
f"Tracker: checking quiescence - active={list(self._active)}, "
f"uncommitted={self._uncommitted_messages}, "
f"has_started={self._has_started}, "
f"total_committed={self._total_committed}, "
f"is_quiescent={is_quiescent}"
)
if is_quiescent:
logger.info(
f"Tracker: quiescence detected (committed={self._total_committed})"
)
self._quiescence_event.set()

def _is_quiescent_unlocked(self) -> bool:
Expand All @@ -165,12 +178,12 @@ def _is_quiescent_unlocked(self) -> bool:
- No messages waiting to be committed
- At least some work was done
"""
return (
not self._active
and self._uncommitted_messages == 0
and self._has_started
and self._total_committed > 0
is_quiescent = not self._active and self._uncommitted_messages == 0
logger.debug(
f"Tracker: _is_quiescent_unlocked check - active={list(self._active)}, "
f"uncommitted={self._uncommitted_messages}, is_quiescent={is_quiescent}"
)
return is_quiescent

async def is_quiescent(self) -> bool:
"""
Expand Down Expand Up @@ -263,6 +276,5 @@ async def get_metrics(self) -> Dict:
return {
"active_nodes": list(self._active),
"uncommitted_messages": self._uncommitted_messages,
"total_committed": self._total_committed,
"is_quiescent": self._is_quiescent_unlocked(),
}
Loading
Loading