From ea3151ad784c0dd0688578a729a69481f64cd41c Mon Sep 17 00:00:00 2001 From: Vitor Hugo Date: Sat, 21 Mar 2026 00:39:16 -0300 Subject: [PATCH] fix: resolve task leak, unsafe auto-install, fragile exception handling, and add routing tests - dispatcher: track handler tasks in a set to prevent GC mid-execution and log unhandled exceptions via done callbacks - base: add handler_tasks set, pass it to dispatch_to_handlers, cancel all handler tasks on stop() - load_balancer: replace silent pip auto-install with ImportError pointing to pip install genesis[redis] - inbound: remove inner try/except that tested tracer errors by string matching; OTel API guarantees start_as_current_span never raises - tests: add test_routing.py with 18 tests covering ChannelRoutingStrategy, GlobalRoutingStrategy, CompositeRoutingStrategy, and dispatch_to_handlers Co-Authored-By: Claude Sonnet 4.6 --- genesis/group/load_balancer.py | 35 +-- genesis/inbound.py | 19 +- genesis/protocol/base.py | 6 +- genesis/protocol/routing/dispatcher.py | 29 ++- tests/test_inbound.py | 10 - tests/test_routing.py | 303 +++++++++++++++++++++++++ 6 files changed, 344 insertions(+), 58 deletions(-) create mode 100644 tests/test_routing.py diff --git a/genesis/group/load_balancer.py b/genesis/group/load_balancer.py index 3104163..290d07c 100644 --- a/genesis/group/load_balancer.py +++ b/genesis/group/load_balancer.py @@ -8,7 +8,6 @@ from __future__ import annotations import asyncio -import sys from typing import Protocol, List, Optional, Any, Awaitable from abc import ABC, abstractmethod @@ -23,7 +22,7 @@ async def _create_redis_client(url: str = "redis://localhost:6379") -> Any: """ - Create a Redis async client, installing redis package if needed. + Create a Redis async client. Internal helper function. @@ -34,35 +33,15 @@ async def _create_redis_client(url: str = "redis://localhost:6379") -> Any: Redis async client instance Raises: - RuntimeError: If redis package cannot be installed or imported + ImportError: If redis package is not installed """ - # Import here to handle optional dependency try: import redis.asyncio as redis_module except ImportError: - # Try to install redis automatically - try: - proc = await asyncio.create_subprocess_exec( - sys.executable, - "-m", - "pip", - "install", - "redis>=5.0.0", - stdout=asyncio.subprocess.DEVNULL, - stderr=asyncio.subprocess.DEVNULL, - ) - await proc.wait() - if proc.returncode != 0: - raise RuntimeError( - "Redis package is required for RedisLoadBalancer. " - "Install it with: pip install redis" - ) - import redis.asyncio as redis_module - except (OSError, ImportError): - raise RuntimeError( - "Redis package is required for RedisLoadBalancer. " - "Install it with: pip install redis" - ) + raise ImportError( + "The redis package is required for RedisLoadBalancer. " + "Install it with: pip install genesis[redis]" + ) return await redis_module.from_url(url) @@ -171,7 +150,7 @@ class RedisLoadBalancer: Redis-based load balancer backend. Tracks call counts in Redis. Suitable for horizontal scaling. - The redis package is automatically installed when needed. + Requires the redis extra: pip install genesis[redis] Args: url: Redis connection URL (default: "redis://localhost:6379") diff --git a/genesis/inbound.py b/genesis/inbound.py index e67fcc3..2c7b8ae 100644 --- a/genesis/inbound.py +++ b/genesis/inbound.py @@ -94,18 +94,13 @@ async def authenticate(self) -> None: async def start(self) -> None: """Initiates an authenticated connection to a freeswitch server.""" try: - try: - with tracer.start_as_current_span( - "inbound_connect", - attributes={ - "net.peer.name": self.host, - "net.peer.port": self.port, - }, - ): - await self._connect() - except Exception as e: - if "tracer" not in str(e).lower(): - raise + with tracer.start_as_current_span( + "inbound_connect", + attributes={ + "net.peer.name": self.host, + "net.peer.port": self.port, + }, + ): await self._connect() except TimeoutError: logger.debug("A timeout occurred when trying to connect to the freeswitch.") diff --git a/genesis/protocol/base.py b/genesis/protocol/base.py index 86cb4d9..4e154cd 100644 --- a/genesis/protocol/base.py +++ b/genesis/protocol/base.py @@ -72,6 +72,7 @@ def __init__(self): self.writer: Optional[StreamWriter] = None self.handlers: Dict[str, List[EventHandler]] = {} self.channel_registry: Dict[str, List[EventHandler]] = {} + self.handler_tasks: set[Task[Any]] = set() # Initialize routing strategy (Strategy Pattern) self.routing_strategy = CompositeRoutingStrategy( @@ -100,6 +101,9 @@ async def stop(self) -> None: self.is_connected = False await self._cancel_task(self.producer, "event producer") await self._cancel_task(self.consumer, "event consumer") + for task in list(self.handler_tasks): + await self._cancel_task(task, "handler task") + self.handler_tasks.clear() async def _cancel_task( self, task: Optional[Task[Any]], label: str = "task" @@ -190,7 +194,7 @@ async def _process_one_event(self, event: ESLEvent) -> None: handlers, _ = await self.routing_strategy.route(event) if handlers: - dispatch_to_handlers(handlers, event) + dispatch_to_handlers(handlers, event, self.handler_tasks) def on( self, diff --git a/genesis/protocol/routing/dispatcher.py b/genesis/protocol/routing/dispatcher.py index 48bf468..8f47131 100644 --- a/genesis/protocol/routing/dispatcher.py +++ b/genesis/protocol/routing/dispatcher.py @@ -5,23 +5,38 @@ Helper for dispatching events to handlers asynchronously. """ -from asyncio import create_task, to_thread, iscoroutinefunction -from typing import List +from asyncio import Task, create_task, to_thread, iscoroutinefunction +from typing import List, Optional, Set, Any +from genesis.observability import logger from genesis.protocol.parser import ESLEvent from genesis.types import EventHandler -def dispatch_to_handlers(handlers: List[EventHandler], event: ESLEvent) -> None: - """Dispatch event to all handlers asynchronously (fire-and-forget tasks). +def _handler_done_callback(task_set: Set[Task[Any]], task: Task[Any]) -> None: + task_set.discard(task) + if not task.cancelled() and task.exception() is not None: + logger.error(f"Unhandled exception in event handler: {task.exception()}") + + +def dispatch_to_handlers( + handlers: List[EventHandler], + event: ESLEvent, + task_set: Optional[Set[Task[Any]]] = None, +) -> None: + """Dispatch event to all handlers asynchronously. Args: handlers: List of event handlers event: The ESL event to dispatch + task_set: Optional set to track live tasks (prevents GC and logs exceptions) """ - _tasks: list = [] for handler in handlers: if iscoroutinefunction(handler): - _tasks.append(create_task(handler(event))) + task = create_task(handler(event)) else: - _tasks.append(create_task(to_thread(handler, event))) + task = create_task(to_thread(handler, event)) + + if task_set is not None: + task_set.add(task) + task.add_done_callback(lambda t: _handler_done_callback(task_set, t)) diff --git a/tests/test_inbound.py b/tests/test_inbound.py index 1705ca8..00f6b52 100644 --- a/tests/test_inbound.py +++ b/tests/test_inbound.py @@ -131,16 +131,6 @@ async def test_inbound_client_send_command_error(freeswitch): await client.send("uptime") -async def test_inbound_tracer_fallback(freeswitch): - async with freeswitch: - with patch( - "genesis.inbound.tracer.start_as_current_span", - side_effect=Exception("Tracer error"), - ): - async with Inbound(*freeswitch.address) as client: - assert client.is_connected - - async def test_inbound_metrics_error_on_start(freeswitch): async with freeswitch: with patch( diff --git a/tests/test_routing.py b/tests/test_routing.py new file mode 100644 index 0000000..1546b46 --- /dev/null +++ b/tests/test_routing.py @@ -0,0 +1,303 @@ +""" +Tests for routing strategies and dispatcher. +""" + +import asyncio +import logging +from asyncio import Task +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from genesis.protocol.parser import ESLEvent +from genesis.protocol.routing.channel import ChannelRoutingStrategy +from genesis.protocol.routing.global_ import GlobalRoutingStrategy +from genesis.protocol.routing.composite import CompositeRoutingStrategy +from genesis.protocol.routing.dispatcher import dispatch_to_handlers + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_event(**kwargs: str) -> ESLEvent: + event = ESLEvent() + event.update(kwargs) + return event + + +# --------------------------------------------------------------------------- +# ChannelRoutingStrategy +# --------------------------------------------------------------------------- + + +async def test_channel_routing_match_returns_handlers_and_stops(): + uuid = "test-uuid-123" + registry = {f"{uuid}:CHANNEL_STATE": [MagicMock()]} + strategy = ChannelRoutingStrategy(registry) + event = make_event(**{"Unique-ID": uuid, "Event-Name": "CHANNEL_STATE"}) + + handlers, should_stop = await strategy.route(event) + + assert handlers == registry[f"{uuid}:CHANNEL_STATE"] + assert should_stop is True + + +async def test_channel_routing_miss_returns_empty_and_false(): + strategy = ChannelRoutingStrategy({}) + event = make_event(**{"Unique-ID": "unknown-uuid", "Event-Name": "CHANNEL_STATE"}) + + handlers, should_stop = await strategy.route(event) + + assert handlers == [] + assert should_stop is False + + +async def test_channel_routing_custom_event_uses_subclass(): + uuid = "abc" + subclass = "mod_audio_stream::play" + registry = {f"{uuid}:{subclass}": [MagicMock()]} + strategy = ChannelRoutingStrategy(registry) + event = make_event( + **{"Unique-ID": uuid, "Event-Name": "CUSTOM", "Event-Subclass": subclass} + ) + + handlers, should_stop = await strategy.route(event) + + assert len(handlers) == 1 + assert should_stop is True + + +async def test_channel_routing_no_uuid_returns_empty(): + strategy = ChannelRoutingStrategy({"x:HEARTBEAT": [MagicMock()]}) + event = make_event(**{"Event-Name": "HEARTBEAT"}) + + handlers, should_stop = await strategy.route(event) + + assert handlers == [] + assert should_stop is False + + +# --------------------------------------------------------------------------- +# GlobalRoutingStrategy +# --------------------------------------------------------------------------- + + +async def test_global_routing_match_by_event_name(): + handler = MagicMock() + strategy = GlobalRoutingStrategy({"HEARTBEAT": [handler]}) + event = make_event(**{"Event-Name": "HEARTBEAT"}) + + handlers, should_stop = await strategy.route(event) + + assert handler in handlers + assert should_stop is False + + +async def test_global_routing_wildcard_matches_any_event(): + wildcard = MagicMock() + strategy = GlobalRoutingStrategy({"*": [wildcard]}) + event = make_event(**{"Event-Name": "CHANNEL_CREATE"}) + + handlers, should_stop = await strategy.route(event) + + assert wildcard in handlers + assert should_stop is False + + +async def test_global_routing_specific_and_wildcard_combined(): + specific = MagicMock() + wildcard = MagicMock() + strategy = GlobalRoutingStrategy({"HEARTBEAT": [specific], "*": [wildcard]}) + event = make_event(**{"Event-Name": "HEARTBEAT"}) + + handlers, _ = await strategy.route(event) + + assert specific in handlers + assert wildcard in handlers + + +async def test_global_routing_miss_returns_empty(): + strategy = GlobalRoutingStrategy({"HEARTBEAT": [MagicMock()]}) + event = make_event(**{"Event-Name": "CHANNEL_HANGUP"}) + + handlers, should_stop = await strategy.route(event) + + assert handlers == [] + assert should_stop is False + + +async def test_global_routing_no_event_name_returns_empty(): + strategy = GlobalRoutingStrategy({"*": [MagicMock()]}) + event = ESLEvent() + + handlers, should_stop = await strategy.route(event) + + assert handlers == [] + assert should_stop is False + + +# --------------------------------------------------------------------------- +# CompositeRoutingStrategy +# --------------------------------------------------------------------------- + + +async def test_composite_stops_at_first_strategy_with_handlers(): + h1 = MagicMock() + h2 = MagicMock() + uuid = "u1" + channel_registry = {f"{uuid}:CHANNEL_STATE": [h1]} + global_handlers = {"CHANNEL_STATE": [h2]} + + composite = CompositeRoutingStrategy( + [ + ChannelRoutingStrategy(channel_registry), + GlobalRoutingStrategy(global_handlers), + ] + ) + event = make_event(**{"Unique-ID": uuid, "Event-Name": "CHANNEL_STATE"}) + + handlers, should_stop = await composite.route(event) + + assert h1 in handlers + assert h2 not in handlers + assert should_stop is True + + +async def test_composite_falls_through_to_next_strategy(): + h2 = MagicMock() + composite = CompositeRoutingStrategy( + [ + ChannelRoutingStrategy({}), + GlobalRoutingStrategy({"HEARTBEAT": [h2]}), + ] + ) + event = make_event(**{"Event-Name": "HEARTBEAT"}) + + handlers, should_stop = await composite.route(event) + + assert h2 in handlers + assert should_stop is False + + +async def test_composite_no_match_returns_empty(): + composite = CompositeRoutingStrategy( + [ + ChannelRoutingStrategy({}), + GlobalRoutingStrategy({}), + ] + ) + event = make_event(**{"Event-Name": "HEARTBEAT"}) + + handlers, should_stop = await composite.route(event) + + assert handlers == [] + assert should_stop is False + + +async def test_composite_propagates_should_stop_true(): + """A strategy returning (handlers, True) must propagate should_stop=True.""" + h = MagicMock() + uuid = "u2" + channel_registry = {f"{uuid}:CHANNEL_STATE": [h]} + composite = CompositeRoutingStrategy( + [ + ChannelRoutingStrategy(channel_registry), + ] + ) + event = make_event(**{"Unique-ID": uuid, "Event-Name": "CHANNEL_STATE"}) + + _, should_stop = await composite.route(event) + + assert should_stop is True + + +# --------------------------------------------------------------------------- +# dispatch_to_handlers +# --------------------------------------------------------------------------- + + +async def test_dispatch_creates_tasks_in_set(): + done = asyncio.Event() + + async def handler(event: ESLEvent) -> None: + done.set() + + task_set: set[Task[Any]] = set() + event = make_event(**{"Event-Name": "HEARTBEAT"}) + + dispatch_to_handlers([handler], event, task_set) + + # At least one task should be in the set immediately after dispatch + assert len(task_set) == 1 + + await asyncio.wait_for(done.wait(), timeout=5) + + +async def test_dispatch_task_removed_from_set_after_completion(): + done = asyncio.Event() + + async def handler(event: ESLEvent) -> None: + done.set() + + task_set: set[Task[Any]] = set() + event = make_event(**{"Event-Name": "HEARTBEAT"}) + + dispatch_to_handlers([handler], event, task_set) + await asyncio.wait_for(done.wait(), timeout=5) + # Give the done callback a chance to run + await asyncio.sleep(0) + + assert len(task_set) == 0 + + +async def test_dispatch_exception_in_handler_is_logged_not_propagated(): + from unittest.mock import patch + + async def bad_handler(event: ESLEvent) -> None: + raise ValueError("handler exploded") + + task_set: set[Task[Any]] = set() + event = make_event(**{"Event-Name": "HEARTBEAT"}) + + with patch("genesis.protocol.routing.dispatcher.logger") as mock_log: + dispatch_to_handlers([bad_handler], event, task_set) + tasks = list(task_set) + for task in tasks: + try: + await asyncio.wait_for(asyncio.shield(task), timeout=5) + except Exception: + pass + await asyncio.sleep(0) + + error_calls = mock_log.error.call_args_list + assert any("handler exploded" in str(call) for call in error_calls) + + +async def test_dispatch_sync_handler_via_to_thread(): + results = [] + + def sync_handler(event: ESLEvent) -> None: + results.append(event.get("Event-Name")) + + task_set: set[Task[Any]] = set() + event = make_event(**{"Event-Name": "HEARTBEAT"}) + + dispatch_to_handlers([sync_handler], event, task_set) + tasks = list(task_set) + await asyncio.gather(*tasks) + + assert results == ["HEARTBEAT"] + + +async def test_dispatch_no_task_set_does_not_raise(): + """dispatch_to_handlers must work without a task_set (backwards compat).""" + done = asyncio.Event() + + async def handler(event: ESLEvent) -> None: + done.set() + + event = make_event(**{"Event-Name": "HEARTBEAT"}) + dispatch_to_handlers([handler], event) + await asyncio.wait_for(done.wait(), timeout=5)