From 9181149ce3308c078850ea470cfa45984c6c14db Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Thu, 6 Feb 2025 19:14:40 +0100 Subject: [PATCH 1/5] feat: add stream stats endpoint This commit adds a new stream stats endpoint which can be used to retrieve the fps metrics in a way that doesn't affect performance. --- server/app.py | 92 +++++++++++++++++++++++++++++++++++++++++++++---- server/utils.py | 71 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 154 insertions(+), 9 deletions(-) diff --git a/server/app.py b/server/app.py index 2a896e57..2547ba44 100644 --- a/server/app.py +++ b/server/app.py @@ -12,12 +12,14 @@ RTCConfiguration, RTCIceServer, MediaStreamTrack, - RTCDataChannel, ) +import threading +import av from aiortc.rtcrtpsender import RTCRtpSender from aiortc.codecs import h264 from pipeline import Pipeline -from utils import patch_loop_datagram +from utils import patch_loop_datagram, StreamStats +import time logger = logging.getLogger(__name__) logging.getLogger('aiortc.rtcrtpsender').setLevel(logging.WARNING) @@ -29,16 +31,82 @@ class VideoStreamTrack(MediaStreamTrack): + """video stream track that processes video frames using a pipeline. + + Attributes: + kind (str): The kind of media, which is "video" for this class. + track (MediaStreamTrack): The underlying media stream track. + pipeline (Pipeline): The processing pipeline to apply to each video frame. + """ + kind = "video" - def __init__(self, track: MediaStreamTrack, pipeline): + def __init__(self, track: MediaStreamTrack, pipeline: Pipeline): + """Initialize the VideoStreamTrack. + + Args: + track: The underlying media stream track. + pipeline: The processing pipeline to apply to each video frame. 
+ """ super().__init__() self.track = track self.pipeline = pipeline - - async def recv(self): - frame = await self.track.recv() - return await self.pipeline(frame) + self._frame_count = 0 + self._start_time = time.monotonic() + self._lock = threading.Lock() + self._fps = 0.0 + self._running = True + self._start_fps_thread() + + def _start_fps_thread(self): + """Start a separate thread to calculate FPS periodically.""" + self.fps_thread = threading.Thread(target=self._calculate_fps_loop, daemon=True) + self.fps_thread.start() + + def _calculate_fps_loop(self): + """Loop to calculate FPS periodically.""" + while self._running: + time.sleep(1) # Calculate FPS every second. + with self._lock: + current_time = time.monotonic() + time_diff = current_time - self._start_time + if time_diff > 0: + self._fps = self._frame_count / time_diff + + # Reset start_time and frame_count for the next interval. + self._start_time = current_time + self._frame_count = 0 + + def stop(self): + """Stop the FPS calculation thread.""" + self._running = False + self.fps_thread.join() + + @property + def fps(self) -> float: + """Get the current output frames per second (FPS). + + Returns: + The current output FPS. + """ + with self._lock: + return self._fps + + async def recv(self) -> av.VideoFrame: + """Receive and process a video frame. Called by the WebRTC library when a frame + is received. + + Returns: + The processed video frame. + """ + input_frame = await self.track.recv() + processed_frame = await self.pipeline(input_frame) + + # Increment frame count for FPS calculation. + with self._lock: + self._frame_count += 1 + + return processed_frame def force_codec(pc, sender, forced_codec): @@ -156,6 +224,10 @@ def on_track(track): tracks["video"] = videoTrack sender = pc.addTrack(videoTrack) + # Store video track in app for stats. 
+ stream_id = track.id + request.app["video_tracks"][stream_id] = videoTrack + codec = "video/H264" force_codec(pc, sender, codec) @@ -207,6 +279,7 @@ async def on_startup(app: web.Application): cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True ) app["pcs"] = set() + app["video_tracks"] = {} async def on_shutdown(app: web.Application): @@ -251,4 +324,9 @@ async def on_shutdown(app: web.Application): app.router.add_post("/prompt", set_prompt) app.router.add_get("/", health) + # Add routes for getting stream statistics. + stream_stats = StreamStats(app) + app.router.add_get("/stats", stream_stats.get_stats) + app.router.add_get("/stats/{stream_id}", stream_stats.get_stats_by_id) + web.run_app(app, host=args.host, port=int(args.port)) diff --git a/server/utils.py b/server/utils.py index db263f88..858f431d 100644 --- a/server/utils.py +++ b/server/utils.py @@ -1,9 +1,13 @@ +"""Utility functions for the server.""" + import asyncio import random import types import logging - -from typing import List, Tuple +import json +from aiohttp import web +from aiortc import MediaStreamTrack +from typing import List, Tuple, Any, Dict logger = logging.getLogger(__name__) @@ -48,3 +52,66 @@ async def create_datagram_endpoint( loop.create_datagram_endpoint = types.MethodType(create_datagram_endpoint, loop) loop._patch_done = True + + +class StreamStats: + """Class to get stream statistics.""" + + def __init__(self, app: web.Application): + """Initialize the StreamStats class.""" + self._app = app + + def get_video_track_stats(self, video_track: MediaStreamTrack) -> Dict[str, Any]: + """Get statistics for a video track. + + Args: + video_track: The VideoStreamTrack instance. + + Returns: + A dictionary containing the statistics. + """ + return { + "fps": video_track.fps, + } + + async def get_stats(self, _) -> web.Response: + """Get the current stream statistics for all streams. + + Args: + request: The HTTP GET request. 
+ + Returns: + The HTTP response containing the statistics. + """ + video_tracks = self._app.get("video_tracks", {}) + all_stats = { + stream_id: self.get_video_track_stats(track) + for stream_id, track in video_tracks.items() + } + + return web.Response( + content_type="application/json", + text=json.dumps(all_stats), + ) + + async def get_stats_by_id(self, request: web.Request) -> web.Response: + """Get the statistics for a specific stream by ID. + + Args: + request: The HTTP GET request. + + Returns: + The HTTP response containing the statistics. + """ + stream_id = request.match_info.get("stream_id") + video_track = self._app["video_tracks"].get(stream_id) + + if video_track: + stats = self.get_video_track_stats(video_track) + else: + stats = {"error": "Stream not found"} + + return web.Response( + content_type="application/json", + text=json.dumps(stats), + ) From 19fc0acb79820d31004508195e8010ef12d833c2 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Fri, 7 Feb 2025 08:10:34 +0100 Subject: [PATCH 2/5] refactor: add `live` prefix This commit ensures that the paths are also available under the `live` prefix. This will allow consistency with the hosted experience and improve the user experience. --- server/app.py | 14 ++++++++++---- server/utils.py | 13 +++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/server/app.py b/server/app.py index 2547ba44..43f7ba3b 100644 --- a/server/app.py +++ b/server/app.py @@ -18,7 +18,7 @@ from aiortc.rtcrtpsender import RTCRtpSender from aiortc.codecs import h264 from pipeline import Pipeline -from utils import patch_loop_datagram, StreamStats +from utils import patch_loop_datagram, StreamStats, add_prefix_to_app_routes import time logger = logging.getLogger(__name__) @@ -320,13 +320,19 @@ async def on_shutdown(app: web.Application): app.on_startup.append(on_startup) app.on_shutdown.append(on_shutdown) + app.router.add_get("/", health) + + # WebRTC signalling and control routes. 
app.router.add_post("/offer", offer) app.router.add_post("/prompt", set_prompt) - app.router.add_get("/", health) # Add routes for getting stream statistics. stream_stats = StreamStats(app) - app.router.add_get("/stats", stream_stats.get_stats) - app.router.add_get("/stats/{stream_id}", stream_stats.get_stats_by_id) + app.router.add_get("/streams/stats", stream_stats.get_stats) + app.router.add_get("/stream/{stream_id}/stats", stream_stats.get_stats_by_id) + + # Add hosted platform route prefix. + # NOTE: This ensures that the local and hosted experiences have consistent routes. + add_prefix_to_app_routes(app, "/live") web.run_app(app, host=args.host, port=int(args.port)) diff --git a/server/utils.py b/server/utils.py index 858f431d..7e6bc18c 100644 --- a/server/utils.py +++ b/server/utils.py @@ -54,6 +54,19 @@ async def create_datagram_endpoint( loop._patch_done = True +def add_prefix_to_app_routes(app: web.Application, prefix: str): + """Add a prefix to all routes in the given application. + + Args: + app: The web application whose routes will be prefixed. + prefix: The prefix to add to all routes. + """ + prefix = prefix.rstrip("/") + for route in list(app.router.routes()): + new_path = prefix + route.resource.canonical + app.router.add_route(route.method, new_path, route.handler) + + class StreamStats: """Class to get stream statistics.""" From fab3b64c7f4fe057b52ae467a9e813f195728ae4 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Fri, 7 Feb 2025 16:43:08 +0100 Subject: [PATCH 3/5] refactor: improve internal fps parameter naming This commit improves the naming of the parameters that are used in the fps calculation to ensure they are more descriptive. 
--- server/app.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/server/app.py b/server/app.py index 43f7ba3b..21389bd5 100644 --- a/server/app.py +++ b/server/app.py @@ -51,8 +51,8 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline): super().__init__() self.track = track self.pipeline = pipeline - self._frame_count = 0 - self._start_time = time.monotonic() + self._fps_interval_frame_count = 0 + self._last_fps_calculation_time = time.monotonic() self._lock = threading.Lock() self._fps = 0.0 self._running = True @@ -60,8 +60,8 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline): def _start_fps_thread(self): """Start a separate thread to calculate FPS periodically.""" - self.fps_thread = threading.Thread(target=self._calculate_fps_loop, daemon=True) - self.fps_thread.start() + self._fps_thread = threading.Thread(target=self._calculate_fps_loop, daemon=True) + self._fps_thread.start() def _calculate_fps_loop(self): """Loop to calculate FPS periodically.""" @@ -69,18 +69,18 @@ def _calculate_fps_loop(self): time.sleep(1) # Calculate FPS every second. with self._lock: current_time = time.monotonic() - time_diff = current_time - self._start_time + time_diff = current_time - self._last_fps_calculation_time if time_diff > 0: - self._fps = self._frame_count / time_diff + self._fps = self._fps_interval_frame_count / time_diff # Reset start_time and frame_count for the next interval. - self._start_time = current_time - self._frame_count = 0 + self._last_fps_calculation_time = current_time + self._fps_interval_frame_count = 0 def stop(self): """Stop the FPS calculation thread.""" self._running = False - self.fps_thread.join() + self._fps_thread.join() @property def fps(self) -> float: @@ -104,7 +104,7 @@ async def recv(self) -> av.VideoFrame: # Increment frame count for FPS calculation. 
+ with self._lock: - self._frame_count += 1 + self._fps_interval_frame_count += 1 return processed_frame From 85ebe5308a79a0f810df68c00707f8b4db57fdaf Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Fri, 7 Feb 2025 17:18:18 +0100 Subject: [PATCH 4/5] fix: ensure video stream is removed when stream ends This commit ensures that the video stream reference is removed from the app's `video_tracks` object when a stream ends, preventing potential memory leaks and stale stats data, and ensuring proper cleanup. --- server/app.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server/app.py b/server/app.py index 21389bd5..85e9111c 100644 --- a/server/app.py +++ b/server/app.py @@ -234,6 +234,7 @@ def on_track(track): @track.on("ended") async def on_ended(): logger.info(f"{track.kind} track ended") + request.app["video_tracks"].pop(track.id, None) @pc.on("connectionstatechange") async def on_connectionstatechange(): From 0b4a2978869e77d016898b805c2393fb8aa8399d Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Fri, 7 Feb 2025 16:38:57 +0100 Subject: [PATCH 5/5] feat: add 'frame_delay' to stats This commit adds the frame delay between the processing time and the expected presentation time to the `/stats` endpoints. 
--- server/app.py | 75 ++++++++++++++++++++++++++++++++++++++----------- server/utils.py | 1 + 2 files changed, 59 insertions(+), 17 deletions(-) diff --git a/server/app.py b/server/app.py index 85e9111c..71c3f69c 100644 --- a/server/app.py +++ b/server/app.py @@ -22,8 +22,9 @@ import time logger = logging.getLogger(__name__) -logging.getLogger('aiortc.rtcrtpsender').setLevel(logging.WARNING) -logging.getLogger('aiortc.rtcrtpreceiver').setLevel(logging.WARNING) +logging.getLogger("aiortc").setLevel(logging.DEBUG) +logging.getLogger("aiortc.rtcrtpsender").setLevel(logging.WARNING) +logging.getLogger("aiortc.rtcrtpreceiver").setLevel(logging.WARNING) MAX_BITRATE = 2000000 @@ -55,12 +56,21 @@ def __init__(self, track: MediaStreamTrack, pipeline: Pipeline): self._last_fps_calculation_time = time.monotonic() self._lock = threading.Lock() self._fps = 0.0 + self._frame_delay = 0.0 self._running = True + self._fps_interval_frame_count = 0 + self._last_fps_calculation_time = time.monotonic() + self._stream_start_time = None + self._last_frame_presentation_time = None + self._last_frame_processed_time = None self._start_fps_thread() + self._start_frame_delay_thread() def _start_fps_thread(self): """Start a separate thread to calculate FPS periodically.""" - self._fps_thread = threading.Thread(target=self._calculate_fps_loop, daemon=True) + self._fps_thread = threading.Thread( + target=self._calculate_fps_loop, daemon=True + ) self._fps_thread.start() def _calculate_fps_loop(self): @@ -77,10 +87,27 @@ def _calculate_fps_loop(self): self._last_fps_calculation_time = current_time self._fps_interval_frame_count = 0 + def _start_frame_delay_thread(self): + """Start a separate thread to calculate frame delay periodically.""" + self._frame_delay_thread = threading.Thread( + target=self._calculate_frame_delay_loop, daemon=True + ) + self._frame_delay_thread.start() + + def _calculate_frame_delay_loop(self): + """Loop to calculate frame delay periodically.""" + while self._running: + 
time.sleep(1) # Calculate frame delay every second. + with self._lock: + if self._last_frame_presentation_time is not None: + current_time = time.monotonic() + self._frame_delay = (current_time - self._stream_start_time ) - float(self._last_frame_presentation_time) + def stop(self): """Stop the FPS calculation thread.""" self._running = False self._fps_thread.join() + self._frame_delay_thread.join() @property def fps(self) -> float: @@ -92,6 +119,16 @@ def fps(self) -> float: with self._lock: return self._fps + @property + def frame_delay(self) -> float: + """Get the current frame delay. + + Returns: + The current frame delay. + """ + with self._lock: + return self._frame_delay + async def recv(self) -> av.VideoFrame: """Receive and process a video frame. Called by the WebRTC library when a frame is received. @@ -99,12 +136,17 @@ async def recv(self) -> av.VideoFrame: Returns: The processed video frame. """ + if self._stream_start_time is None: + self._stream_start_time = time.monotonic() + input_frame = await self.track.recv() processed_frame = await self.pipeline(input_frame) - # Increment frame count for FPS calculation. + # Store frame info for stats. 
with self._lock: self._fps_interval_frame_count += 1 + self._last_frame_presentation_time = input_frame.time + self._last_frame_processed_time = time.monotonic() return processed_frame @@ -187,30 +229,29 @@ async def offer(request): @pc.on("datachannel") def on_datachannel(channel): if channel.label == "control": + @channel.on("message") async def on_message(message): try: params = json.loads(message) - + if params.get("type") == "get_nodes": nodes_info = await pipeline.get_nodes_info() - response = { - "type": "nodes_info", - "nodes": nodes_info - } + response = {"type": "nodes_info", "nodes": nodes_info} channel.send(json.dumps(response)) elif params.get("type") == "update_prompt": if "prompt" not in params: - logger.warning("[Control] Missing prompt in update_prompt message") + logger.warning( + "[Control] Missing prompt in update_prompt message" + ) return pipeline.set_prompt(params["prompt"]) - response = { - "type": "prompt_updated", - "success": True - } + response = {"type": "prompt_updated", "success": True} channel.send(json.dumps(response)) else: - logger.warning("[Server] Invalid message format - missing required fields") + logger.warning( + "[Server] Invalid message format - missing required fields" + ) except json.JSONDecodeError: logger.error("[Server] Invalid JSON received") except Exception as e: @@ -310,8 +351,8 @@ async def on_shutdown(app: web.Application): logging.basicConfig( level=args.log_level.upper(), - format='%(asctime)s [%(levelname)s] %(message)s', - datefmt='%H:%M:%S' + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%H:%M:%S", ) app = web.Application() diff --git a/server/utils.py b/server/utils.py index 7e6bc18c..d31c0a92 100644 --- a/server/utils.py +++ b/server/utils.py @@ -85,6 +85,7 @@ def get_video_track_stats(self, video_track: MediaStreamTrack) -> Dict[str, Any] """ return { "fps": video_track.fps, + "frame_delay": video_track.frame_delay, } async def get_stats(self, _) -> web.Response: