Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/ezmsg/core/backpressure.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import time

from uuid import UUID

from typing import Literal

from .profiling import LeaseDurationTelemetry

class BufferLease:
"""
Expand Down Expand Up @@ -83,7 +85,9 @@ class Backpressure:
empty: asyncio.Event
pressure: int

def __init__(self, num_buffers: int) -> None:
def __init__(
self, num_buffers: int, telemetry: LeaseDurationTelemetry | None = None
) -> None:
"""
Initialize backpressure management for the specified number of buffers.

Expand All @@ -94,6 +98,7 @@ def __init__(self, num_buffers: int) -> None:
self.empty = asyncio.Event()
self.empty.set()
self.pressure = 0
self._telemetry = telemetry

@property
def is_empty(self) -> bool:
Expand Down Expand Up @@ -138,6 +143,8 @@ def lease(self, uuid: UUID, buf_idx: int) -> None:
self.pressure += 1
self.buffers[buf_idx].add(uuid)
self.empty.clear()
if self._telemetry is not None:
self._telemetry.on_lease(uuid, buf_idx, time.perf_counter())

def _free(self, uuid: UUID, buf_idx: int) -> None:
"""
Expand All @@ -152,6 +159,8 @@ def _free(self, uuid: UUID, buf_idx: int) -> None:
self.buffers[buf_idx].remove(uuid)
if self.buffers[buf_idx].is_empty:
self.pressure -= 1
if self._telemetry is not None:
self._telemetry.on_free(uuid, buf_idx, time.perf_counter())
except KeyError:
pass

Expand All @@ -170,6 +179,8 @@ def free(self, uuid: UUID, buf_idx: int | None = None) -> None:
if buf_idx is None:
for idx in range(len(self.buffers)):
self._free(uuid, idx)
if self._telemetry is not None:
self._telemetry.on_free(uuid, None, time.perf_counter())
else:
self._free(uuid, buf_idx)

Expand Down
15 changes: 13 additions & 2 deletions src/ezmsg/core/channelmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ async def register(
client_id: UUID,
queue: NotificationQueue,
graph_address: AddressType | None = None,
handle: str | None = None,
) -> Channel:
"""
Acquire the channel associated with a particular publisher, creating it if necessary
Expand All @@ -49,10 +50,14 @@ async def register(
:type queue: asyncio.Queue[tuple[UUID, int]]
:param graph_address: The address to the GraphServer that the requested publisher is managed by
:type graph_address: AddressType | None
:param handle: Optional label to associate with the subscribing client for profiling output
:type handle: str | None
:return: A Channel for retreiving messages from the requested Publisher
:rtype: Channel
"""
return await self._register(pub_id, client_id, queue, graph_address, None)
return await self._register(
pub_id, client_id, queue, graph_address, None, handle=handle
)

async def register_local_pub(
self,
Expand Down Expand Up @@ -85,6 +90,7 @@ async def _register(
queue: NotificationQueue | None = None,
graph_address: AddressType | None = None,
local_backpressure: Backpressure | None = None,
handle: str | None = None,
) -> Channel:
graph_address = _ensure_address(graph_address)
try:
Expand All @@ -94,7 +100,12 @@ async def _register(
channels = self._registry.get(graph_address, dict())
channels[pub_id] = channel
self._registry[graph_address] = channels
channel.register_client(client_id, queue, local_backpressure)
channel.register_client(
client_id,
queue,
local_backpressure,
handle=handle,
)
return channel

async def unregister(
Expand Down
17 changes: 16 additions & 1 deletion src/ezmsg/core/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def cmdline() -> None:
parser.add_argument(
"command",
help="command for ezmsg",
choices=["serve", "start", "shutdown", "graphviz", "mermaid"],
choices=["serve", "start", "shutdown", "graphviz", "mermaid", "profile"],
)

parser.add_argument("--address", help="Address for GraphServer", default=None)
Expand Down Expand Up @@ -70,12 +70,21 @@ def cmdline() -> None:
action="store_true",
)

parser.add_argument(
"-w",
"--window",
help="Profiling window (seconds) for the 'profile' command",
type=float,
default=None,
)

class Args:
command: str
address: str | None
target: str
compact: int | None
nobrowser: bool
window: float | None

args = parser.parse_args(namespace=Args)

Expand All @@ -93,6 +102,7 @@ class Args:
args.target,
args.compact,
args.nobrowser,
args.window,
)
)

Expand All @@ -103,6 +113,7 @@ async def run_command(
target: str = "live",
compact: int | None = None,
nobrowser: bool = False,
window: float | None = None,
) -> None:
"""
Run an ezmsg command with the specified parameters.
Expand Down Expand Up @@ -166,6 +177,10 @@ async def run_command(
f"Could not issue shutdown command to GraphServer @ {graph_service.address}; server not running?"
)

elif cmd == "profile":
profile = await graph_service.profile(window)
print(json.dumps(profile, indent=2))

elif cmd in ["graphviz", "mermaid"]:
graph_out = await graph_service.get_formatted_graph(
fmt=cmd, compact_level=compact
Expand Down
104 changes: 104 additions & 0 deletions src/ezmsg/core/graphserver.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import asyncio
import json
import logging
import pickle
import os
import socket
import threading
from typing import Any
from contextlib import suppress
from uuid import UUID, uuid1


from . import __version__
from .dag import DAG, CyclicException
from .graph_util import get_compactified_graph, graph_string, prune_graph_connections
from .profiling import PROFILE_WINDOW_S
from .netprotocol import (
Address,
Command,
Expand Down Expand Up @@ -78,6 +81,7 @@ def __init__(self, **kwargs) -> None:
self.clients = {}
self._client_tasks = {}
self.shms = {}
self._profile_requests: dict[UUID, asyncio.Future] = {}

@property
def address(self) -> Address:
Expand Down Expand Up @@ -119,6 +123,10 @@ async def _shutdown_async(self) -> None:
with suppress(asyncio.CancelledError):
await asyncio.gather(*self._client_tasks.values(), return_exceptions=True)
self._client_tasks.clear()
for fut in list(self._profile_requests.values()):
if not fut.done():
fut.cancel()
self._profile_requests.clear()

# Cancel SHM leases
for info in self.shms.values():
Expand Down Expand Up @@ -338,6 +346,14 @@ async def api(
writer.write(uint64_to_bytes(len(dag_bytes)) + dag_bytes)
writer.write(Command.COMPLETE.value)

elif req == Command.PROFILE.value:
window_ms = await read_int(reader)
profile = await self._collect_profiles(window_ms / 1000.0)
writer.write(Command.PROFILE_DATA.value)
writer.write(encode_str(json.dumps(profile)))
writer.write(Command.COMPLETE.value)
await writer.drain()

else:
logger.warning(f"GraphServer received unknown command {req}")

Expand Down Expand Up @@ -380,6 +396,15 @@ async def _handle_client(

if req == Command.COMPLETE.value:
self.clients[client_id].set_sync()
elif req == Command.PROFILE_DATA.value:
payload = await read_str(reader)
future = self._profile_requests.pop(client_id, None)
if future is not None and not future.done():
try:
future.set_result(json.loads(payload))
except json.JSONDecodeError as e:
future.set_exception(e)
self.clients[client_id].set_sync()

except (ConnectionResetError, BrokenPipeError) as e:
logger.debug(f"Client {client_id} disconnected from GraphServer: {e}")
Expand All @@ -388,6 +413,9 @@ async def _handle_client(
# Ensure any waiter on this client unblocks
# with suppress(Exception):
self.clients[client_id].set_sync()
future = self._profile_requests.pop(client_id, None)
if future is not None and not future.done():
future.cancel()
self.clients.pop(client_id, None)
await close_stream_writer(writer)

Expand Down Expand Up @@ -428,6 +456,64 @@ def _downstream_subs(self, topic: str) -> list[SubscriberInfo]:
downstream_topics = self.graph.downstream(topic)
return [sub for sub in self._subscribers() if sub.topic in downstream_topics]

async def _profile_client(
self, info: ClientInfo, window_s: float
) -> dict[str, Any] | None:
loop = asyncio.get_running_loop()
fut: asyncio.Future[dict[str, Any]] = loop.create_future()
self._profile_requests[info.id] = fut

try:
async with info.sync_writer() as writer:
writer.write(Command.PROFILE.value)
writer.write(uint64_to_bytes(int(window_s * 1000)))
await writer.drain()

return await asyncio.wait_for(fut, timeout=1.0)

except asyncio.TimeoutError:
logger.debug(f"Profile request to {info.id} timed out")
return None

finally:
self._profile_requests.pop(info.id, None)

async def _collect_profiles(self, window_s: float) -> dict[str, Any]:
client_infos = [
info
for info in self.clients.values()
if isinstance(info, (PublisherInfo, ChannelInfo))
]
results = await asyncio.gather(
*[self._profile_client(info, window_s) for info in client_infos],
return_exceptions=True,
)

publishers: list[dict[str, Any]] = []
channels: list[dict[str, Any]] = []

for info, result in zip(client_infos, results):
if isinstance(result, Exception) or result is None:
continue

if isinstance(info, PublisherInfo) and result.get("type") == "publisher":
publishers.append(result)

elif isinstance(info, ChannelInfo) and result.get("type") == "channel":
try:
pub_info = self.clients.get(UUID(result["pub_id"]), None)
if isinstance(pub_info, PublisherInfo):
result.setdefault("topic", pub_info.topic)
except Exception:
...
channels.append(result)

return {
"window_s": window_s,
"publishers": publishers,
"channels": channels,
}


class GraphService:
ADDR_ENV = GRAPHSERVER_ADDR_ENV
Expand Down Expand Up @@ -544,6 +630,24 @@ async def dag(self, timeout: float | None = None) -> DAG:
await close_stream_writer(writer)
return dag

async def profile(self, window_s: float | None = None) -> dict[str, Any]:
reader, writer = await self.open_connection()
if window_s is None:
window_s = PROFILE_WINDOW_S
writer.write(Command.PROFILE.value)
writer.write(uint64_to_bytes(int(window_s * 1000)))
await writer.drain()

response = await reader.read(1)
if response != Command.PROFILE_DATA.value:
await close_stream_writer(writer)
raise ValueError("Unexpected response to profile request")

payload = await read_str(reader)
await reader.read(1) # COMPLETE
await close_stream_writer(writer)
return json.loads(payload)

async def get_formatted_graph(
self,
fmt: str,
Expand Down
Loading