Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 123 additions & 72 deletions asab/api/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@


L = logging.getLogger(__name__)
LogObsolete = logging.getLogger('OBSOLETE')


class DiscoveryService(Service):
Expand All @@ -34,13 +35,15 @@ def __init__(self, app, zkc, service_name="asab.DiscoveryService") -> None:
from .internal_auth import InternalAuth
self.InternalAuth = InternalAuth(app, zkc)

self.BasePath = "/" + self.ZooKeeperContainer.Path + "/run"

self._advertised_cache = dict()
self._advertised_raw = dict()

self._cache_lock = asyncio.Lock()
self._ready_event = asyncio.Event()

self.App.PubSub.subscribe("Application.tick/300!", self._on_tick)
self.App.PubSub.subscribe("Application.tick/600!", self._on_tick600)
self.App.PubSub.subscribe("ZooKeeperContainer.state/CONNECTED!", self._on_zk_ready)


Expand All @@ -49,13 +52,62 @@ async def initialize(self, app):
await self.InternalAuth.initialize(app)


def _on_tick(self, msg):
def _on_tick600(self, _msg):
# Full rescan of the advertised instances every 10 minutes
self.App.TaskService.schedule(self._rescan_advertised_instances())


def _on_zk_ready(self, msg, zkcontainer):
	"""
	Handle the `ZooKeeperContainer.state/CONNECTED!` event.

	Schedules a full rescan of the advertised instances and installs a persistent
	recursive watch on `self.BasePath` so that subsequent individual changes are
	picked up without waiting for the periodic rescan.
	"""
	# Any ZooKeeper container can fire this event; react only to our own.
	if zkcontainer != self.ZooKeeperContainer:
		return

	self.App.TaskService.schedule(self._rescan_advertised_instances())

	# Install a persistent watch on the base path to detect changes in the advertised
	# instances every time the ZooKeeper connection is established.
	# NOTE(review): the watch callback runs on a Kazoo thread, not the event loop.
	zkcontainer.ZooKeeper.Client.add_watch(
		self.BasePath,
		self._on_change_zookeeper_thread,
		kazoo.protocol.states.AddWatchMode.PERSISTENT_RECURSIVE
	)


def _on_change_zookeeper_thread(self, event):
	"""
	Persistent-watch callback; invoked by Kazoo on a non-asyncio thread.

	Ignores events arriving in a non-CONNECTED state and events without a path,
	then forwards the change to the main event loop for processing.
	"""
	if event.state != 'CONNECTED':
		return

	if event.path is None:
		return

	# Handle the change event in a thread-safe manner in the main event loop thread.
	# The "<BasePath>/" prefix is stripped to pass only the bare item name.
	self.App.TaskService.schedule_threadsafe(self._on_change(event.path[len(self.BasePath) + 1:], event.type))


async def _on_change(self, item, event_type):
	"""
	Apply a single ZooKeeper watch event to the raw advertised-instances cache
	and rebuild the derived lookup cache.

	:param item: Name of the changed znode, relative to `self.BasePath`.
	:param event_type: Kazoo event type string: 'CREATED', 'CHANGED' or 'DELETED'.
	"""
	async with self._cache_lock:

		if event_type == 'CREATED' or event_type == 'CHANGED':
			# The item is new or changed - read the data and update the cache.
			# Kazoo client calls are synchronous/blocking, so run the read in the
			# thread pool (same pattern as `_iter_zk_items`) instead of blocking
			# the asyncio event loop.
			def fetch_item():
				return self.ZooKeeperContainer.ZooKeeper.Client.get(self.BasePath + '/' + item)

			try:
				data, _stat = await self.ProactorService.execute(fetch_item)
				self._advertised_raw[item] = json.loads(data)
			except (kazoo.exceptions.SessionExpiredError, kazoo.exceptions.ConnectionLoss):
				L.warning("Connection to ZooKeeper lost. Discovery Service could not fetch up-to-date state of the cluster services.")
				return
			except kazoo.exceptions.NoNodeError:
				# The node vanished between the event and the read; nothing to update.
				return

		elif event_type == 'DELETED':
			# The item is deleted - remove it from the cache
			prev = self._advertised_raw.pop(item, None)
			if prev is None:
				# Never cached in the first place; the derived cache needs no rebuild.
				return

		else:
			L.warning("Unexpected event type: {}".format(event_type))
			return

	# Apply the changes to the cache.
	# Called outside the `async with` block because `_apply_advertised_raw`
	# acquires `self._cache_lock` itself (asyncio.Lock is not reentrant).
	await self._apply_advertised_raw()
Comment on lines +85 to +110
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Blocking ZooKeeper call on the async event loop.

Line 91 calls self.ZooKeeperContainer.ZooKeeper.Client.get() directly inside an async method. Kazoo's client methods are synchronous and blocking—this will block the entire asyncio event loop until the ZooKeeper response arrives, degrading throughput and responsiveness.

Use ProactorService.execute() to run the blocking call in a thread pool, consistent with how _iter_zk_items handles it:

Proposed fix
 async def _on_change(self, item, event_type):
 	async with self._cache_lock:
 
 		if event_type == 'CREATED' or event_type == 'CHANGED':
 			# The item is new or changed - read the data and update the cache
 			try:
-				data, _stat = self.ZooKeeperContainer.ZooKeeper.Client.get(self.BasePath + '/' + item)
-				self._advertised_raw[item] = json.loads(data)
+				def fetch_item():
+					return self.ZooKeeperContainer.ZooKeeper.Client.get(self.BasePath + '/' + item)
+				data, _stat = await self.ProactorService.execute(fetch_item)
+				self._advertised_raw[item] = json.loads(data)
 			except (kazoo.exceptions.SessionExpiredError, kazoo.exceptions.ConnectionLoss):

Additionally, event_type comparisons at lines 88, 99, 105 use string literals. Consider using kazoo.protocol.states.EventType enum values for consistency (e.g., EventType.CREATED).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@asab/api/discovery.py` around lines 85 - 110, The _on_change coroutine
currently calls the blocking kazoo Client.get() directly (in _on_change) which
will block the asyncio loop; wrap the blocking call in ProactorService.execute
(same pattern used in _iter_zk_items) so ZooKeeper.Client.get(self.BasePath +
'/' + item) runs in the threadpool and returns its result to the async function,
preserving the existing exception handling for SessionExpiredError,
ConnectionLoss and NoNodeError; also replace string literal event_type checks in
_on_change with kazoo.protocol.states.EventType (e.g., EventType.CREATED,
EventType.CHANGED, EventType.DELETED) to make the comparisons consistent.



async def locate(self, instance_id: str = None, **kwargs) -> set:
Expand Down Expand Up @@ -87,6 +139,7 @@ async def locate(self, instance_id: str = None, **kwargs) -> set:
in await self._locate(locate_params)
])


async def _locate(self, locate_params) -> typing.Set[typing.Tuple]:
"""
Locate service instances based on their instance ID or service ID.
Expand Down Expand Up @@ -114,7 +167,6 @@ async def _locate(self, locate_params) -> typing.Set[typing.Tuple]:


async def discover(self) -> typing.Dict[str, typing.Dict[str, typing.Set[typing.Tuple]]]:
	"""
	Return the cache of advertised instances: a mapping of identifier type
	(e.g. "instance_id", "service_id") to identifier to a set of
	`(host, port, address_family)` tuples.

	Waits up to 600 seconds for the cache to be populated at least once.

	NOTE(review): the internal cache object is returned directly, not a copy —
	callers must treat it as read-only, or a defensive copy should be made here.
	"""
	await asyncio.wait_for(self._ready_event.wait(), 600)
	return self._advertised_cache

Expand All @@ -130,7 +182,7 @@ async def get_advertised_instances(self) -> typing.List[typing.Dict]:
Returns a list of dictionaries. Each dictionary represents an advertised instance
obtained by iterating over the items in the `/run` path in ZooKeeper.
"""
# TODO: an obsolete log for this method
LogObsolete.warning("get_advertised_instances() is deprecated. Use discover() method instead. This method will be removed after Sep 2026")
advertised = []
for item, item_data in await self._iter_zk_items():
item_data['ephemeral_id'] = item
Expand Down Expand Up @@ -166,92 +218,96 @@ async def _rescan_advertised_instances(self):
...
}
"""

if self._cache_lock.locked():
# Only one rescan / cache update at a time
return


async with self._cache_lock:
try:
prev_keys = set(self._advertised_raw.keys())
for item, item_data in await self._iter_zk_items():
self._advertised_raw[item] = item_data
prev_keys.discard(item)
for item in prev_keys:
self._advertised_raw.pop(item, None)
except asyncio.CancelledError:
raise
except Exception:
L.exception("Error when scanning advertised instances")
return

advertised = {
"instance_id": {},
"service_id": {},
}
await self._apply_advertised_raw()

advertised_raw = {}

try:
for item, item_data in await self._iter_zk_items():
async def _apply_advertised_raw(self):
advertised = {
"instance_id": {},
"service_id": {},
}

advertised_raw[item] = item_data
async with self._cache_lock:
Comment on lines +240 to +249
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Race window between releasing lock and _apply_advertised_raw re-acquiring it.

In _rescan_advertised_instances, the _cache_lock is released at the end of the async with block (after line 238), then _apply_advertised_raw is called at line 240 which re-acquires the lock at line 249. Similarly, _on_change releases the lock after modifying _advertised_raw (line 107) before calling _apply_advertised_raw (line 110).

This creates a window where another coroutine can modify _advertised_raw between unlock and relock, potentially causing inconsistent state.

Consider either:

  1. Keeping _apply_advertised_raw call inside the lock, or
  2. Making _apply_advertised_raw work on a snapshot of _advertised_raw
Option 1: Call _apply_advertised_raw inside the lock
 	async with self._cache_lock:
 		try:
 			prev_keys = set(self._advertised_raw.keys())
 			for item, item_data in await self._iter_zk_items():
 				self._advertised_raw[item] = item_data
 				prev_keys.discard(item)
 			for item in prev_keys:
 				self._advertised_raw.pop(item, None)
 		except asyncio.CancelledError:
 			raise
 		except Exception:
 			L.exception("Error when scanning advertised instances")
 			return
+		await self._apply_advertised_raw()
-
-	await self._apply_advertised_raw()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@asab/api/discovery.py` around lines 240 - 249, Race window: callers
(_rescan_advertised_instances and _on_change) release _cache_lock then call
_apply_advertised_raw which re-acquires the lock, allowing concurrent mutation
of _advertised_raw; fix by having _apply_advertised_raw operate on a snapshot
passed in by callers. Change _apply_advertised_raw to accept an advertised
snapshot argument (e.g. advertised_snapshot) and have callers
(_rescan_advertised_instances and _on_change) create a shallow/deep copy of
_advertised_raw while holding _cache_lock and pass that copy to
_apply_advertised_raw; remove or avoid re-acquiring _cache_lock inside
_apply_advertised_raw so it processes the stable snapshot without race.

for item_data in self._advertised_raw.values():
instance_id = item_data.get("instance_id")
service_id = item_data.get("service_id")
discovery: typing.Dict[str, list] = item_data.get("discovery", {})

if instance_id is not None:
discovery["instance_id"] = [instance_id]

if service_id is not None:
discovery["service_id"] = [service_id]

instance_id = item_data.get("instance_id")
service_id = item_data.get("service_id")
discovery: typing.Dict[str, list] = item_data.get("discovery", {})
host = item_data.get("host")
if host is None:
continue

if instance_id is not None:
discovery["instance_id"] = [instance_id]
web = item_data.get("web")
if web is None:
continue

if service_id is not None:
discovery["service_id"] = [service_id]
for i in web:

host = item_data.get("host")
if host is None:
try:
ip = i[0]
port = i[1]
except (IndexError, TypeError, KeyError):
L.error("Unexpected format of 'web' section in advertised data: '{}'".format(web))
continue

web = item_data.get("web")
if web is None:
if ip == "0.0.0.0":
family = socket.AF_INET
elif ip == "::":
family = socket.AF_INET6
else:
continue

for i in web:

try:
ip = i[0]
port = i[1]
except KeyError:
L.error("Unexpected format of 'web' section in advertised data: '{}'".format(web))
continue

if ip == "0.0.0.0":
family = socket.AF_INET
elif ip == "::":
family = socket.AF_INET6
else:
continue

if discovery is not None:
for id_type, ids in discovery.items():
if advertised.get(id_type) is None:
advertised[id_type] = {}

for identifier in ids:
if identifier is not None:
if advertised[id_type].get(identifier) is None:
advertised[id_type][identifier] = {(host, port, family)}
else:
advertised[id_type][identifier].add((host, port, family))
except Exception:
L.exception("Error when scanning advertised instances")
return
if discovery is not None:
for id_type, ids in discovery.items():
if advertised.get(id_type) is None:
advertised[id_type] = {}

for identifier in ids:
if identifier is not None:
if advertised[id_type].get(identifier) is None:
advertised[id_type][identifier] = {(host, port, family)}
else:
advertised[id_type][identifier].add((host, port, family))

# TODO: Transform _advertised_cache and _advertised_raw into read-only structures
self._advertised_cache = advertised
self._advertised_raw = advertised_raw

self._ready_event.set()
self._ready_event.set()


async def _iter_zk_items(self):
base_path = self.ZooKeeperContainer.Path + "/run"

def get_items():
result = []
try:
# Create the base path if it does not exist
if not self.ZooKeeperContainer.ZooKeeper.Client.exists(base_path):
self.ZooKeeperContainer.ZooKeeper.Client.create(base_path, b'', makepath=True)
if not self.ZooKeeperContainer.ZooKeeper.Client.exists(self.BasePath):
self.ZooKeeperContainer.ZooKeeper.Client.create(self.BasePath, b'', makepath=True)

items = self.ZooKeeperContainer.ZooKeeper.Client.get_children(base_path, watch=self._on_change_threadsafe)
items = self.ZooKeeperContainer.ZooKeeper.Client.get_children(self.BasePath)

except (kazoo.exceptions.SessionExpiredError, kazoo.exceptions.ConnectionLoss):
L.warning("Connection to ZooKeeper lost. Discovery Service could not fetch up-to-date state of the cluster services.")
Expand All @@ -263,7 +319,7 @@ def get_items():

for item in items:
try:
data, stat = self.ZooKeeperContainer.ZooKeeper.Client.get(base_path + '/' + item, watch=self._on_change_threadsafe)
data, _stat = self.ZooKeeperContainer.ZooKeeper.Client.get(self.BasePath + '/' + item)
result.append((item, json.loads(data)))
except (kazoo.exceptions.SessionExpiredError, kazoo.exceptions.ConnectionLoss):
L.warning("Connection to ZooKeeper lost. Discovery Service could not fetch up-to-date state of the cluster services.")
Expand All @@ -280,11 +336,6 @@ def get_items():
return result


def _on_change_threadsafe(self, watched_event):
# Runs on a thread, returns the process back to the main thread
self.App.TaskService.schedule_threadsafe(self._rescan_advertised_instances())


def session(
self,
base_url: typing.Optional[str] = None,
Expand Down Expand Up @@ -417,7 +468,7 @@ async def resolve(self, hostname: str, port: int = 0, family: int = socket.AF_IN
hosts.extend(resolved)

if len(hosts) == 0:
raise NotDiscoveredError("Failed to resolve any of the hosts for '{}' / '{}'.".format(hostname, ','.join(x[0] for x in set(x[0] for x in located_instances))))
raise NotDiscoveredError("Failed to resolve any of the hosts for '{}' / '{}'.".format(hostname, ','.join(x for x in set(x[0] for x in located_instances))))

return hosts

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def run(self):
install_requires=[
'aiohttp>=3.8.3,<4',
'fastjsonschema>=2.16.2,<3',
'kazoo>=2.9.0,<3',
'kazoo @ git+https://github.com/TeskaLabs/kazoo.git',
'PyYAML>=6.0,<7',
],
extras_require={
Expand Down
Loading