diff --git a/dev-requirements.txt b/dev-requirements.txt index efc24f80..be256006 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -81,6 +81,8 @@ orjson==3.10.15 # via fastapi packaging==24.2 # via pytest +pillow==11.3.0 + # via labthings-fastapi (pyproject.toml) pluggy==1.5.0 # via pytest pydantic==2.10.6 diff --git a/pyproject.toml b/pyproject.toml index 0505259c..ffdb4ee6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dev = [ "mypy>=1.6.1, <2", "ruff>=0.1.3", "types-jsonschema", + "Pillow", ] [project.urls] diff --git a/src/labthings_fastapi/outputs/mjpeg_stream.py b/src/labthings_fastapi/outputs/mjpeg_stream.py index 3b7b88a8..9ffd9aeb 100644 --- a/src/labthings_fastapi/outputs/mjpeg_stream.py +++ b/src/labthings_fastapi/outputs/mjpeg_stream.py @@ -41,7 +41,11 @@ def __init__(self, gen: AsyncGenerator[bytes, None], status_code: int = 200): This response is initialised with an async generator that yields `bytes` objects, each of which is a JPEG file. We add the --frame markers and mime - types that enable it to work in an `img` tag. + types that mark it as an MJPEG stream. This is sufficient to enable it to + work in an `img` tag, with the `src` set to the MJPEG stream's endpoint. + + It expects an async generator that supplies individual JPEGs to be streamed, + such as the one provided by `.MJPEGStream`. NB the ``status_code`` argument is used by FastAPI to set the status code of the response in OpenAPI. @@ -63,6 +67,24 @@ async def mjpeg_async_generator(self) -> AsyncGenerator[bytes, None]: class MJPEGStream: + """Manage streaming images over HTTP as an MJPEG stream + + An MJPEGStream object handles accepting images (already in + JPEG format) and streaming them to HTTP clients as a multipart + response. + + The minimum needed to make the stream work is to periodically + call `add_frame` with JPEG image data. + + To add a stream to a `.Thing`, use the `.MJPEGStreamDescriptor` + which will handle creating an `MJPEGStream` object on first access, + and will also add it to the HTTP API. + + The MJPEG stream buffers the last few frames (10 by default) and + also has a hook to notify the size of each frame as it is added. + The latter is used by OpenFlexure's autofocus routine. + """ + def __init__(self, ringbuffer_size: int = 10): self._lock = threading.Lock() self.condition = anyio.Condition() @@ -85,10 +107,11 @@ def reset(self, ringbuffer_size: Optional[int] = None): ] self.last_frame_i = -1 - def stop(self): + def stop(self, portal: BlockingPortal): """Stop the stream""" with self._lock: self._streaming = False + portal.start_task_soon(self.notify_stream_stopped) async def ringbuffer_entry(self, i: int) -> RingbufferEntry: """Return the ith frame acquired by the camera @@ -117,9 +140,13 @@ async def buffer_for_reading(self, i: int) -> AsyncIterator[bytes]: yield entry.frame async def next_frame(self) -> int: - """Wait for the next frame, and return its index""" + """Wait for the next frame, and return its index + + :raises StopAsyncIteration: if the stream has stopped.""" async with self.condition: await self.condition.wait() + if not self._streaming: + raise StopAsyncIteration() return self.last_frame_i async def grab_frame(self) -> bytes: @@ -148,6 +175,8 @@ async def frame_async_generator(self) -> AsyncGenerator[bytes, None]: i = await self.next_frame() async with self.buffer_for_reading(i) as frame: yield frame + except StopAsyncIteration: + break except Exception as e: logging.error(f"Error in stream: {e}, stream stopped") return @@ -156,7 +185,7 @@ async def mjpeg_stream_response(self) -> MJPEGStreamResponse: """Return a StreamingResponse that streams an MJPEG stream""" return MJPEGStreamResponse(self.frame_async_generator()) - def add_frame(self, frame: bytes, portal: BlockingPortal): + def add_frame(self, frame: bytes, portal: BlockingPortal) -> None: """Return the next buffer in the ringbuffer to write to :param frame: The frame to add @@ -174,15 +203,31 @@ def add_frame(self, frame: bytes, portal: BlockingPortal): entry.index = self.last_frame_i + 1 portal.start_task_soon(self.notify_new_frame, entry.index) - async def notify_new_frame(self, i): - """Notify any waiting tasks that a new frame is available""" + async def notify_new_frame(self, i: int) -> None: + """Notify any waiting tasks that a new frame is available. + + :param i: The number of the frame (which counts up since the server starts) + """ async with self.condition: self.last_frame_i = i self.condition.notify_all() + async def notify_stream_stopped(self) -> None: + """Raise an exception in any waiting tasks to signal the stream has stopped.""" + assert self._streaming is False + async with self.condition: + self.condition.notify_all() + class MJPEGStreamDescriptor: - """A descriptor that returns a MJPEGStream object when accessed""" + """A descriptor that returns a MJPEGStream object when accessed + + If this descriptor is added to a `.Thing`, it will create an `.MJPEGStream` + object when it is first accessed. It will also add two HTTP endpoints, + one with the name of the descriptor serving the MJPEG stream, and another + with `/viewer` appended, which serves a basic HTML page that views the stream. + + """ def __init__(self, **kwargs): self._kwargs = kwargs diff --git a/tests/test_mjpeg_stream.py b/tests/test_mjpeg_stream.py new file mode 100644 index 00000000..b88df318 --- /dev/null +++ b/tests/test_mjpeg_stream.py @@ -0,0 +1,78 @@ +import io +import threading +import time +from PIL import Image +from fastapi.testclient import TestClient +import labthings_fastapi as lt + + +class Telly(lt.Thing): + _stream_thread: threading.Thread + _streaming: bool = False + framerate: float = 1000 + frame_limit: int = 3 + + stream = lt.outputs.MJPEGStreamDescriptor() + + def __enter__(self): + self._streaming = True + self._stream_thread = threading.Thread(target=self._make_images) + self._stream_thread.start() + + def __exit__(self, exc_t, exc_v, exc_tb): + self._streaming = False + self._stream_thread.join() + + def _make_images(self): + """Stream a series of solid colours""" + colours = ["#F00", "#0F0", "#00F"] + jpegs = [] + for c in colours: + image = Image.new("RGB", (10, 10), c) + dest = io.BytesIO() + image.save(dest, "jpeg") + jpegs.append(dest.getvalue()) + + i = 0 + while self._streaming and (i < self.frame_limit or self.frame_limit < 0): + self.stream.add_frame( + jpegs[i % len(jpegs)], self._labthings_blocking_portal + ) + time.sleep(1 / self.framerate) + i = i + 1 + self.stream.stop(self._labthings_blocking_portal) + self._streaming = False + + +def test_mjpeg_stream(): + """Verify the MJPEG stream contains at least one frame marker. + + A limitation of the TestClient is that it can't actually stream. + This means that all of the frames sent by our test Thing will + arrive in a single packet. + + For now, we just check it starts with the frame separator, + but it might be possible in the future to check there are three + images there. + """ + server = lt.ThingServer() + telly = Telly() + server.add_thing(telly, "telly") + with TestClient(server.app) as client: + with client.stream("GET", "/telly/stream") as stream: + stream.raise_for_status() + received = 0 + for b in stream.iter_bytes(): + received += 1 + assert b.startswith(b"--frame") + + +if __name__ == "__main__": + import uvicorn + + server = lt.ThingServer() + telly = Telly() + telly.framerate = 6 + telly.frame_limit = -1 + server.add_thing(telly, "telly") + uvicorn.run(server.app, port=5000)