diff --git a/docs/source/quickstart/counter.py b/docs/source/quickstart/counter.py index ea2990c5..21eff71d 100644 --- a/docs/source/quickstart/counter.py +++ b/docs/source/quickstart/counter.py @@ -22,8 +22,8 @@ def slowly_increase_counter(self) -> None: time.sleep(1) self.increment_counter() - counter = lt.ThingProperty( - model=int, initial_value=0, readonly=True, description="A pointless counter" + counter = lt.ThingProperty[int]( + initial_value=0, readonly=True, description="A pointless counter" ) diff --git a/examples/counter.py b/examples/counter.py index 95438422..a13a81ec 100644 --- a/examples/counter.py +++ b/examples/counter.py @@ -23,8 +23,8 @@ def slowly_increase_counter(self) -> None: time.sleep(1) self.increment_counter() - counter = lt.ThingProperty( - model=int, initial_value=0, readonly=True, description="A pointless counter" + counter = lt.ThingProperty[int]( + initial_value=0, readonly=True, description="A pointless counter" ) diff --git a/examples/demo_thing_server.py b/examples/demo_thing_server.py index 17166048..fc242b9a 100644 --- a/examples/demo_thing_server.py +++ b/examples/demo_thing_server.py @@ -58,12 +58,11 @@ def slowly_increase_counter(self): time.sleep(1) self.increment_counter() - counter = lt.ThingProperty( - model=int, initial_value=0, readonly=True, description="A pointless counter" + counter = lt.ThingProperty[int]( + initial_value=0, readonly=True, description="A pointless counter" ) - foo = lt.ThingProperty( - model=str, + foo = lt.ThingProperty[str]( initial_value="Example", description="A pointless string for demo purposes.", ) diff --git a/examples/opencv_camera_server.py b/examples/opencv_camera_server.py new file mode 100644 index 00000000..6db2a0af --- /dev/null +++ b/examples/opencv_camera_server.py @@ -0,0 +1,292 @@ +import logging +import threading + +from fastapi import FastAPI +from fastapi.responses import HTMLResponse, StreamingResponse +from labthings_fastapi.descriptors.property import ThingProperty +from labthings_fastapi.thing import Thing +from labthings_fastapi.decorators import thing_action, thing_property +from labthings_fastapi.server import ThingServer +from labthings_fastapi.file_manager import FileManagerDep +from typing import Optional, AsyncContextManager +from collections.abc import AsyncGenerator +from functools import partial +from dataclasses import dataclass +from datetime import datetime +from contextlib import asynccontextmanager +import anyio +from anyio.from_thread import BlockingPortal +from threading import RLock +import cv2 as cv + +logging.basicConfig(level=logging.INFO) + + +@dataclass +class RingbufferEntry: + """A single entry in a ringbuffer""" + + frame: bytes + timestamp: datetime + index: int + readers: int = 0 + + +class MJPEGStreamResponse(StreamingResponse): + media_type = "multipart/x-mixed-replace; boundary=frame" + + def __init__(self, gen: AsyncGenerator[bytes, None], status_code: int = 200): + """A StreamingResponse that streams an MJPEG stream + + This response is initialised with an async generator that yields `bytes` + objects, each of which is a JPEG file. We add the --frame markers and mime + types that enable it to work in an `img` tag. + + NB the `status_code` argument is used by FastAPI to set the status code of + the response in OpenAPI. + """ + self.frame_async_generator = gen + StreamingResponse.__init__( + self, + self.mjpeg_async_generator(), + media_type=self.media_type, + status_code=status_code, + ) + + async def mjpeg_async_generator(self) -> AsyncGenerator[bytes, None]: + """A generator yielding an MJPEG stream""" + async for frame in self.frame_async_generator: + yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + yield frame + yield b"\r\n" + + +class MJPEGStream: + def __init__(self, ringbuffer_size: int = 10): + self._lock = threading.Lock() + self.condition = anyio.Condition() + self._streaming = False + self.reset(ringbuffer_size=ringbuffer_size) + + def reset(self, ringbuffer_size: Optional[int] = None): + """Reset the stream and optionally change the ringbuffer size""" + with self._lock: + self._streaming = True + n = ringbuffer_size or len(self._ringbuffer) + self._ringbuffer = [ + RingbufferEntry( + frame=b"", + index=-1, + timestamp=datetime.min, + ) + for i in range(n) + ] + self.last_frame_i = -1 + + def stop(self): + """Stop the stream""" + with self._lock: + self._streaming = False + + async def ringbuffer_entry(self, i: int) -> RingbufferEntry: + """Return the `i`th frame acquired by the camera""" + if i < 0: + raise ValueError("i must be >= 0") + if i < self.last_frame_i - len(self._ringbuffer) + 2: + raise ValueError("the ith frame has been overwritten") + if i > self.last_frame_i: + # TODO: await the ith frame + raise ValueError("the ith frame has not yet been acquired") + entry = self._ringbuffer[i % len(self._ringbuffer)] + if entry.index != i: + raise ValueError("the ith frame has been overwritten") + return entry + + @asynccontextmanager + async def buffer_for_reading(self, i: int) -> AsyncContextManager[bytes]: + """Yields the ith frame as a bytes object""" + entry = await self.ringbuffer_entry(i) + try: + entry.readers += 1 + yield entry.frame + finally: + entry.readers -= 1 + + async def next_frame(self) -> int: + """Wait for the next frame, and return its index""" + async with self.condition: + await self.condition.wait() + return self.last_frame_i + + async def frame_async_generator(self) -> AsyncGenerator[bytes, None]: + """A generator that yields frames as bytes""" + while self._streaming: + try: + i = await self.next_frame() + async with self.buffer_for_reading(i) as frame: + yield frame + except Exception as e: + logging.error(f"Error in stream: {e}, stream stopped") + return + + async def mjpeg_stream_response(self) -> MJPEGStreamResponse: + """Return a StreamingResponse that streams an MJPEG stream""" + return MJPEGStreamResponse(self.frame_async_generator()) + + def add_frame(self, frame: bytes, portal: BlockingPortal): + """Return the next buffer in the ringbuffer to write to""" + with self._lock: + entry = self._ringbuffer[(self.last_frame_i + 1) % len(self._ringbuffer)] + if entry.readers > 0: + raise RuntimeError("Cannot write to ringbuffer while it is being read") + entry.timestamp = datetime.now() + entry.frame = frame + entry.index = self.last_frame_i + 1 + portal.start_task_soon(self.notify_new_frame, entry.index) + + async def notify_new_frame(self, i): + """Notify any waiting tasks that a new frame is available""" + async with self.condition: + self.last_frame_i = i + self.condition.notify_all() + + +class MJPEGStreamDescriptor: + """A descriptor that returns a MJPEGStream object when accessed""" + + def __init__(self, **kwargs): + self._kwargs = kwargs + + def __set_name__(self, owner, name): + self.name = name + + def __get__(self, obj, type=None) -> MJPEGStream: + """The value of the property + + If `obj` is none (i.e. we are getting the attribute of the class), + we return the descriptor. + + If no getter is set, we'll return either the initial value, or the value + from the object's __dict__, i.e. we behave like a variable. + + If a getter is set, we will use it, unless the property is observable, at + which point the getter is only ever used once, to set the initial value. + """ + if obj is None: + return self + try: + return obj.__dict__[self.name] + except KeyError: + obj.__dict__[self.name] = MJPEGStream(**self._kwargs) + return obj.__dict__[self.name] + + async def viewer_page(self, url: str) -> HTMLResponse: + return HTMLResponse(f"") + + def add_to_fastapi(self, app: FastAPI, thing: Thing): + """Add the stream to the FastAPI app""" + app.get( + f"{thing.path}{self.name}", + response_class=MJPEGStreamResponse, + )(self.__get__(thing).mjpeg_stream_response) + app.get( + f"{thing.path}{self.name}/viewer", + response_class=HTMLResponse, + )(partial(self.viewer_page, f"{thing.path}{self.name}")) + + +class OpenCVCamera(Thing): + """A Thing that represents an OpenCV camera""" + + def __init__(self, device_index: int = 0): + self.device_index = device_index + self._stream_thread: Optional[threading.Thread] = None + + def __enter__(self): + self._cap = cv.VideoCapture(self.device_index) + self._cap_lock = RLock() + if not self._cap.isOpened(): + raise IOError(f"Cannot open camera with device index {self.device_index}") + self.start_streaming() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.stop_streaming() + self._cap.release() + del self._cap + del self._cap_lock + + def start_streaming(self): + print("starting stream...") + if self._stream_thread is not None: + raise RuntimeError("Stream thread already running") + self._stream_thread = threading.Thread(target=self._stream_thread_fn) + self._continue_streaming = True + self._stream_thread.start() + print("started") + + def stop_streaming(self): + print("stopping stream...") + if self._stream_thread is None: + raise RuntimeError("Stream thread not running") + self._continue_streaming = False + self.mjpeg_stream.stop() + print("waiting for stream to join") + self._stream_thread.join() + print("stream stopped.") + self._stream_thread = None + + def _stream_thread_fn(self): + while self._continue_streaming: + with self._cap_lock: + ret, frame = self._cap.read() + if not ret: + logging.error("Could not read frame from camera") + continue + success, array = cv.imencode(".jpg", frame) + if success: + self.mjpeg_stream.add_frame( + frame=array.tobytes(), + portal=self._labthings_blocking_portal, + ) + self.last_frame_index = self.mjpeg_stream.last_frame_i + + @thing_action + def snap_image(self, file_manager: FileManagerDep) -> str: + """Acquire one image from the camera. + + This action cannot run if the camera is in use by a background thread, for + example if a preview stream is running. + """ + with self._cap_lock: + ret, frame = self._cap.read() + if not ret: + raise IOError("Could not read image from camera") + fpath = file_manager.path("image.jpg", rel="image") + cv.imwrite(fpath, frame) + return ( + "image.jpg is available from the links property of this Invocation " + "(see ./files)" + ) + + @thing_property + def exposure(self) -> float: + with self._cap_lock: + return self._cap.get(cv.CAP_PROP_EXPOSURE) + + @exposure.setter + def exposure(self, value): + with self._cap_lock: + self._cap.set(cv.CAP_PROP_EXPOSURE, value) + + last_frame_index = ThingProperty[int](int, initial_value=-1) + + mjpeg_stream = MJPEGStreamDescriptor(ringbuffer_size=10) + + +thing_server = ThingServer() +my_thing = OpenCVCamera() +my_thing.validate_thing_description() +thing_server.add_thing(my_thing, "/camera") + +app = thing_server.app diff --git a/examples/picamera2_camera_server.py b/examples/picamera2_camera_server.py new file mode 100644 index 00000000..e9029e84 --- /dev/null +++ b/examples/picamera2_camera_server.py @@ -0,0 +1,227 @@ +from __future__ import annotations +import logging +import time + +from pydantic import BaseModel, BeforeValidator + +from labthings_fastapi.descriptors.property import ThingProperty +from labthings_fastapi.thing import Thing +from labthings_fastapi.decorators import thing_action, thing_property +from labthings_fastapi.server import ThingServer +from labthings_fastapi.file_manager import FileManagerDep +from typing import Annotated, Any, Iterator, Optional +from contextlib import contextmanager +from anyio.from_thread import BlockingPortal +from threading import RLock +import picamera2 +from picamera2 import Picamera2 +from picamera2.encoders import MJPEGEncoder, Quality +from picamera2.outputs import Output +from labthings_fastapi.outputs.mjpeg_stream import MJPEGStreamDescriptor, MJPEGStream +from labthings_fastapi.utilities import get_blocking_portal + + +logging.basicConfig(level=logging.INFO) + + +class PicameraControl(ThingProperty): + def __init__( + self, control_name: str, model: type = float, description: Optional[str] = None + ): + """A property descriptor controlling a picamera control""" + ThingProperty.__init__(self, model, observable=False, description=description) + self.control_name = control_name + self._getter + + def _getter(self, obj: StreamingPiCamera2): + print(f"getting {self.control_name} from {obj}") + with obj.picamera() as cam: + ret = cam.capture_metadata()[self.control_name] + print(f"Trying to return camera control {self.control_name} as `{ret}`") + return ret + + def _setter(self, obj: StreamingPiCamera2, value: Any): + with obj.picamera() as cam: + setattr(cam.controls, self.control_name, value) + + +class PicameraStreamOutput(Output): + """An Output class that sends frames to a stream""" + + def __init__(self, stream: MJPEGStream, portal: BlockingPortal): + """Create an output that puts frames in an MJPEGStream + + We need to pass the stream object, and also the blocking portal, because + new frame notifications happen in the anyio event loop and frames are + sent from a thread. The blocking portal enables thread-to-async + communication. + """ + Output.__init__(self) + self.stream = stream + self.portal = portal + + def outputframe(self, frame, _keyframe=True, _timestamp=None): + """Add a frame to the stream's ringbuffer""" + self.stream.add_frame(frame, self.portal) + + +class SensorMode(BaseModel): + unpacked: str + bit_depth: int + size: tuple[int, int] + fps: float + crop_limits: tuple[int, int, int, int] + exposure_limits: tuple[Optional[int], Optional[int], Optional[int]] + format: Annotated[str, BeforeValidator(repr)] + + +class StreamingPiCamera2(Thing): + """A Thing that represents an OpenCV camera""" + + def __init__(self, device_index: int = 0): + self.device_index = device_index + self.camera_configs: dict[str, dict] = {} + + stream_resolution = ThingProperty[tuple[int, int]]( + initial_value=(1640, 1232), + description="Resolution to use for the MJPEG stream", + ) + image_resolution = ThingProperty[tuple[int, int]]( + initial_value=(3280, 2464), + description="Resolution to use for still images (by default)", + ) + mjpeg_bitrate = ThingProperty[int]( + initial_value=0, description="Bitrate for MJPEG stream (best left at 0)" + ) + stream_active = ThingProperty[bool]( + initial_value=False, + description="Whether the MJPEG stream is active", + observable=True, + ) + mjpeg_stream = MJPEGStreamDescriptor() + analogue_gain = PicameraControl("AnalogueGain", float) + colour_gains = PicameraControl("ColourGains", tuple[float, float]) + colour_correction_matrix = PicameraControl( + "ColourCorrectionMatrix", + tuple[float, float, float, float, float, float, float, float, float], + ) + exposure_time = PicameraControl( + "ExposureTime", int, description="The exposure time in microseconds" + ) + exposure_time = PicameraControl( + "ExposureTime", int, description="The exposure time in microseconds" + ) + sensor_modes = ThingProperty[list[SensorMode]](readonly=True, getter=list) + + def __enter__(self): + self._picamera = picamera2.Picamera2(camera_num=self.device_index) + self._picamera_lock = RLock() + self.populate_sensor_modes() + self.start_streaming() + return self + + @contextmanager + def picamera(self) -> Iterator[Picamera2]: + with self._picamera_lock: + yield self._picamera + + def populate_sensor_modes(self): + with self.picamera() as cam: + self.sensor_modes = cam.sensor_modes + + def __exit__(self, exc_type, exc_value, traceback): + self.stop_streaming() + with self.picamera() as cam: + cam.close() + del self._picamera + + def start_streaming(self) -> None: + """ + Start the MJPEG stream + + Sets the camera resolution to the video/stream resolution, and starts recording + if the stream should be active. + """ + with self.picamera() as picam: + # TODO: Filip: can we use the lores output to keep preview stream going + # while recording? According to picamera2 docs 4.2.1.6 this should work + try: + if picam.started: + picam.stop() + if picam.encoder is not None and picam.encoder.running: + picam.encoder.stop() + stream_config = picam.create_video_configuration( + main={"size": self.stream_resolution}, + # colour_space=ColorSpace.Rec709(), + ) + picam.configure(stream_config) + logging.info("Starting picamera MJPEG stream...") + picam.start_recording( + MJPEGEncoder( + self.mjpeg_bitrate if self.mjpeg_bitrate > 0 else None, + ), + PicameraStreamOutput( + self.mjpeg_stream, + get_blocking_portal(self), + ), + Quality.HIGH, # TODO: use provided quality + ) + except Exception as e: + logging.info("Error while starting preview:") + logging.exception(e) + else: + self.stream_active = True + logging.debug( + "Started MJPEG stream at %s on port %s", self.stream_resolution, 1 + ) + + def stop_streaming(self) -> None: + """ + Stop the MJPEG stream + """ + with self.picamera() as picam: + try: + picam.stop_recording() + except Exception as e: + logging.info("Stopping recording failed") + logging.exception(e) + else: + self.stream_active = False + self.mjpeg_stream.stop() + logging.info( + f"Stopped MJPEG stream. Switching to {self.image_resolution}." + ) + + # Increase the resolution for taking an image + time.sleep( + 0.2 + ) # Sprinkled a sleep to prevent camera getting confused by rapid commands + + @thing_action + def snap_image(self, file_manager: FileManagerDep) -> str: + """Acquire one image from the camera. + + This action cannot run if the camera is in use by a background thread, for + example if a preview stream is running. + """ + raise NotImplementedError + + @thing_property + def exposure(self) -> float: + raise NotImplementedError() + + @exposure.setter + def exposure(self, value): + raise NotImplementedError() + + last_frame_index = [int](initial_value=-1) + + mjpeg_stream = MJPEGStreamDescriptor(ringbuffer_size=10) + + +thing_server = ThingServer() +my_thing = StreamingPiCamera2() +my_thing.validate_thing_description() +thing_server.add_thing(my_thing, "/camera") + +app = thing_server.app diff --git a/src/labthings_fastapi/decorators/__init__.py b/src/labthings_fastapi/decorators/__init__.py index 8ae57991..b4c98b32 100644 --- a/src/labthings_fastapi/decorators/__init__.py +++ b/src/labthings_fastapi/decorators/__init__.py @@ -33,7 +33,7 @@ """ from functools import wraps, partial -from typing import Optional, Callable +from typing import Optional, Callable, TypeVar from ..descriptors import ( ActionDescriptor, ThingProperty, @@ -72,7 +72,10 @@ def thing_action(func: Optional[Callable] = None, **kwargs): return partial(mark_thing_action, **kwargs) -def thing_property(func: Callable) -> ThingProperty: +Value = TypeVar("Value") + + +def thing_property(func: Callable[..., Value]) -> ThingProperty[Value]: """Mark a method of a Thing as a LabThings Property This should be used as a decorator with a getter and a setter @@ -92,7 +95,7 @@ def thing_property(func: Callable) -> ThingProperty: ) -def thing_setting(func: Callable) -> ThingSetting: +def thing_setting(func: Callable[..., Value]) -> ThingSetting[Value]: """Mark a method of a Thing as a LabThings Setting. A setting is a property that persists between runs. diff --git a/src/labthings_fastapi/descriptors/property.py b/src/labthings_fastapi/descriptors/property.py index f51a31cf..78d8c717 100644 --- a/src/labthings_fastapi/descriptors/property.py +++ b/src/labthings_fastapi/descriptors/property.py @@ -3,7 +3,20 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Annotated, Any, Callable, Optional +from types import EllipsisType +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + Optional, + Generic, + Type, + TypeAlias, + TypeVar, + overload, +) +import typing from weakref import WeakSet from typing_extensions import Self @@ -21,32 +34,131 @@ from ..thing import Thing -class ThingProperty: +class MissingTypeError(TypeError): + """Error raised when a type annotation is missing for a property.""" + + +class MismatchedTypeError(TypeError): + """Error raised when a type annotation does not match the expected type for a property.""" + + +class MissingDefaultError(AttributeError): + """Error raised when a property has no getter or initial value.""" + + +Value = TypeVar("Value") +Owner: TypeAlias = "Thing" +# There was an intention to make ThingProperty generic in 2 variables, one for +# the value and one for the owner, but this was problematic. +# For now, we'll stick to typing the owner as a Thing. +# We may want to search-and-replace the Owner symbol, but I think it is +# helpful for now. I don't think NewType would be appropriate here, +# as it would probably raise errors when defining getter/setter methods. + + +class ThingProperty(Generic[Value]): """A property that can be accessed via the HTTP API By default, a ThingProperty is "dumb", i.e. it acts just like - a normal variable. + a normal variable. It can have a getter and setter, in which case + it will work similarly to a Python property. """ model: type[BaseModel] + """A Pydantic model that describes the type of the property.""" readonly: bool = False + """If True, the property cannot be set via the HTTP API""" + _model_arg: type[Value] + """The type of the model argument, if specified.""" + _value_type: type[Value] + """The type of the value, may or may not be a Pydantic model.""" def __init__( self, - model: type, - initial_value: Any = None, + model: type | None = None, + initial_value: Value | EllipsisType = ..., readonly: bool = False, observable: bool = False, description: Optional[str] = None, title: Optional[str] = None, - getter: Optional[Callable] = None, - setter: Optional[Callable] = None, + getter: Optional[ + Callable[ + [ + Owner, + ], + Value, + ] + ] = None, + setter: Optional[Callable[[Owner, Value], None]] = None, ): - if getter and initial_value is not None: + """A property that can be accessed via the HTTP API + + ThingProperty is a descriptor that functions like a variable, optionally + with notifications when it is set. It may also have a getter and setter, + which work in a similar way to Python properties. + + The type of a property can be set in several ways: + 1. As a type argument on the property itself, e.g. `ThingProperty[int]` + 2. As a type annotation on the class, e.g. `my_property: int = ThingProperty` + 3. As a type annotation on the getter method, e.g. + `@ThingProperty\n def my_property(self) -> int: ...` + 4. As an explicitly set model argument, e.g. `ThingProperty(model=int)` + + All of these are checked, and an error is raised if any of them are inconsistent. + If no type is specified, an error is raised. `model` may be deprecated in the + future. + + ``ThingProperty`` can behave in several different ways: + - If no `getter` or `setter` is specified, it will behave like a simple + data attribute (i.e. a variable). If `observable` is `True`, it is + possible to register for notifications when the value is set. In this + case, an `initial_value` is required. + - If a `getter` is specified and `observable` is `False`, the `getter` + will be called when the property is accessed, and its return value + will be the property's value, just like the builtin `property`. The + property will be read-only both locally and via HTTP. + - If a `getter` is specified and `observable` is `True`, the `getter` + is used instead of `initial_value` but thereafter the property + behaves like a variable. The `getter` is only on first access. + The property may be written to locally, and whether it's writable + via HTTP depends on the `readonly` argument. + - If both a `getter` and `setter` are specified and `observable` is `False`, + the property behaves like a Python property, with the `getter` being + called when the property is accessed, and the `setter` being called + when the property is set. The property is read-only via HTTP if + `readonly` is `True`. It may always be written to locally. + - If `observable` is `True` and a `setter` is specified, the property + will behave like a variable, but will call the `setter` + when the property is set. The `setter` may perform tasks like sending + the updated value to the hardware, but it is not responsible for + remembering the value. The initial value is set via the `getter` or + `initial_value`. + + + :param model: The type of the property. This is optional, because it is + better to use type hints (see notes on typing above). + :param initial_value: The initial value of the property. If this is set, + the property must not have a getter, and should behave like a variable. + :param readonly: If True, the property cannot be set via the HTTP API. + :param observable: If True, the property can be observed for changes via + websockets. This causes the setter to run code in the async event loop + that will notify a list of subscribers each time the property is set. + Currently, only websockets can be used to observe properties. + :param description: A description of the property, used in the API documentation. + LabThings will attempt to take this from the docstring if not supplied. + :param title: A human-readable title for the property, used in the API + documentation. Defaults to the first line of the docstring, or the name + of the property. + :param getter: A function that gets the value of the property. + :param setter: A function that sets the value of the property. + """ + if getter and not isinstance(initial_value, EllipsisType): raise ValueError("getter and an initial value are mutually exclusive.") - if model is None: - raise ValueError("LabThings Properties must have a type") - self.model = wrap_plain_types_in_rootmodel(model) + if isinstance(initial_value, EllipsisType) and getter is None: + raise MissingDefaultError() + # We no longer check types in __init__, as we do that in __set_name__ + if isinstance(model, type): + self._model_arg = model self.readonly = readonly self.observable = observable self.initial_value = initial_value @@ -57,10 +169,82 @@ def __init__( self._getter = getter or getattr(self, "_getter", None) # Try to generate a DataSchema, so that we can raise an error that's easy to # link to the offending ThingProperty - type_to_dataschema(self.model) - def __set_name__(self, owner, name: str): + def __set_name__(self, owner: type[Owner], name: str) -> None: + """Notification of the name and owning class. + + When a descriptor is attached to a class, Python calls this method. + We use it to take note of the property's name and the class it belongs to, + which also allows us to check if there is a type annotation for the property + on the class. + + The type of a property can be set in several ways: + 1. As a type argument on the property itself, e.g. `BaseThingProperty[int]` + 2. As a type annotation on the class, e.g. `my_property: int = BaseThingProperty` + 3. As a type annotation on the getter method, e.g. `@BaseThingProperty\n def my_property(self) -> int: ...` + + There is a model argument, e.g. `BaseThingProperty(model=int)` but this is no longer + supported and will raise an error. + + All of these are checked, and an error is raised if any of them are inconsistent. + If no type is specified, an error is raised. + + This method is called after `__init__`, so if there was a type subscript + (e.g. `BaseThingProperty[ModelType]`), it will be available as + `self.__orig_class__` at this point (but not during `__init__`). + + :param owner: The class that owns this property. + :param name: The name of the property. + + :raises MissingTypeError: If no type annotation is found for the property. + :raises MismatchedTypeError: If multiple type annotations are found and they do not agree. + """ self._name = name + value_types: dict[str, type[Value]] = {} + if hasattr(self, "_model_arg"): + # If we have a model argument, we can use that + value_types["model_argument"] = self._model_arg + if self._getter is not None: + # If the property has a getter, we can extract the type from it + annotations = typing.get_type_hints(self._getter, include_extras=True) + if "return" in annotations: + value_types["getter_return_type"] = annotations["return"] + owner_annotations = typing.get_type_hints(owner, include_extras=True) + if name in owner_annotations: + # If the property has a type annotation on the owning class, we can use that + value_types["class_annotation"] = owner_annotations[name] + if hasattr(self, "__orig_class__"): + # We were instantiated as BaseThingProperty[ModelType] so can use that type + value_types["__orig_class__"] = typing.get_args(self.__orig_class__)[0] + + # Check we have a model, and that it is consistent if it's specified in multiple places + try: + # Pick the first one we find, then check the rest against it + self._value_type = next(iter(value_types.values())) + for v_type in value_types.values(): + if v_type != self._value_type: + raise MismatchedTypeError( + f"Inconsistent model for property '{name}' on '{owner}'. " + f"Types were: {value_types}." + ) + except StopIteration: # This means no types were found, value_types is empty + raise MissingTypeError( + f"Property '{name}' on '{owner}' is missing a type annotation. " + "Please provide a type annotation ." + ) + if len(value_types) == 1 and "model_argument" in value_types: + raise MissingTypeError( + f"Property '{name}' on '{owner}' specifies `model` but is not type annotated." + ) + print( + f"Initializing property '{name}' on '{owner}', {value_types}." + ) # TODO: Debug print statement, remove + # If the model is a plain type, wrap it in a RootModel so that it can be used + # as a FastAPI model. + self.model = wrap_plain_types_in_rootmodel(self._value_type) + # Try to generate a DataSchema, so that we can raise an error that's easy to + # link to the offending ThingProperty + type_to_dataschema(self.model) @property def title(self): @@ -76,7 +260,17 @@ def description(self): """A description of the property""" return self._description or get_docstring(self._getter, remove_summary=True) - def __get__(self, obj, type=None) -> Any: + @overload + def __get__(self, obj: None, owner: Type[Owner]) -> Self: + """Called when an attribute is accessed via class not an instance""" + + @overload + def __get__(self, obj: Owner, owner: Type[Owner] | None) -> Value: + """Called when an attribute is accessed on an instance variable""" + + def __get__( + self, obj: Owner | None, owner: Type[Owner] | None = None + ) -> Value | Self: """The value of the property If `obj` is none (i.e. we are getting the attribute of the class), @@ -101,8 +295,13 @@ def __get__(self, obj, type=None) -> Any: # if we get to here, the property should be observable, so cache obj.__dict__[self.name] = self._getter(obj) return obj.__dict__[self.name] - else: + elif not isinstance(self.initial_value, EllipsisType): return self.initial_value + else: + raise MissingDefaultError( + f"Property '{self.name}' on '{obj.__class__.__name__}' has " + " no value and no getter or initial value." + ) def __set__(self, obj, value): """Set the property's value""" @@ -187,7 +386,7 @@ def set_property(body): # We'll annotate body later description=f"## {self.title}\n\n{self.description or ''}", ) def get_property(): - return self.__get__(thing) + return self.__get__(thing, type(thing)) def property_affordance( self, thing: Thing, path: Optional[str] = None @@ -238,7 +437,7 @@ def setter(self, func: Callable) -> Self: return self -class ThingSetting(ThingProperty): +class ThingSetting(ThingProperty[Value], Generic[Value]): """A setting can be accessed via the HTTP API and is persistent between sessions A ThingSetting is a ThingProperty with extra functionality for triggering diff --git a/src/labthings_fastapi/example_things/__init__.py b/src/labthings_fastapi/example_things/__init__.py index 61d4d7ff..9f42d71b 100644 --- a/src/labthings_fastapi/example_things/__init__.py +++ b/src/labthings_fastapi/example_things/__init__.py @@ -73,12 +73,11 @@ def slowly_increase_counter(self, increments: int = 60, delay: float = 1): time.sleep(delay) self.increment_counter() - counter = ThingProperty( - model=int, initial_value=0, readonly=True, description="A pointless counter" + counter = ThingProperty[int]( + initial_value=0, readonly=True, description="A pointless counter" ) - foo = ThingProperty( - model=str, + foo = ThingProperty[str]( initial_value="Example", description="A pointless string for demo purposes.", ) @@ -103,7 +102,7 @@ def broken_action(self): raise RuntimeError("This is a broken action") @thing_property - def broken_property(self): + def broken_property(self) -> bool: """A property that raises an exception""" raise RuntimeError("This is a broken property") diff --git a/tests/test_action_cancel.py b/tests/test_action_cancel.py index a5a4ad49..f9bcf705 100644 --- a/tests/test_action_cancel.py +++ b/tests/test_action_cancel.py @@ -10,10 +10,9 @@ class CancellableCountingThing(lt.Thing): - counter = lt.ThingProperty(int, 0, observable=False) - check = lt.ThingProperty( - bool, - False, + counter = lt.ThingProperty[int](initial_value=0, observable=False) + check = lt.ThingProperty[bool]( + initial_value=False, observable=False, description=( "This variable is used to check that the action can detect a cancel event " diff --git a/tests/test_action_manager.py b/tests/test_action_manager.py index 4ae2ab4f..430b12ab 100644 --- a/tests/test_action_manager.py +++ b/tests/test_action_manager.py @@ -13,7 +13,7 @@ def increment_counter(self): """Increment the counter""" self.counter += 1 - counter = lt.ThingProperty( + counter = lt.ThingProperty[int]( model=int, initial_value=0, readonly=True, description="A pointless counter" ) diff --git a/tests/test_dependency_metadata.py b/tests/test_dependency_metadata.py index 2536d8f2..cd329838 100644 --- a/tests/test_dependency_metadata.py +++ b/tests/test_dependency_metadata.py @@ -14,7 +14,7 @@ def __init__(self): self._a = 0 @lt.thing_property - def a(self): + def a(self) -> int: return self._a @a.setter diff --git a/tests/test_properties.py b/tests/test_properties.py index 3946c7e1..70f2fb62 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -3,19 +3,29 @@ from pytest import raises from pydantic import BaseModel from fastapi.testclient import TestClient +import pytest import labthings_fastapi as lt from labthings_fastapi.exceptions import NotConnectedToServerError +from labthings_fastapi.descriptors.property import ( + MismatchedTypeError, + MissingTypeError, + MissingDefaultError, +) class TestThing(lt.Thing): - boolprop = lt.ThingProperty(bool, False, description="A boolean property") - stringprop = lt.ThingProperty(str, "foo", description="A string property") + boolprop = lt.ThingProperty[bool]( + initial_value=False, description="A boolean property" + ) + stringprop = lt.ThingProperty[str]( + initial_value="foo", description="A string property" + ) _undoc = None @lt.thing_property - def undoc(self): + def undoc(self) -> None: return self._undoc _float = 1.0 @@ -48,10 +58,68 @@ def test_instantiation_with_type(): Check the internal model (data type) of the ThingSetting descriptor is a BaseModel To send the data over HTTP LabThings-FastAPI uses Pydantic models to describe data - types. + types. Note that the model is not created until the property is assigned to a + `Thing`, as it happens in `__set_name__` of the `ThingProperty` descriptor. """ - prop = lt.ThingProperty(bool, False) - assert issubclass(prop.model, BaseModel) + + class BasicThing(lt.Thing): + prop = lt.ThingProperty[bool](initial_value=False) + + assert issubclass(BasicThing.prop.model, BaseModel) + + +def exception_is_or_is_caused_by(err: Exception, cls: type[Exception]): + return isinstance(err, cls) or isinstance(err.__cause__, cls) + + +def test_instantiation_with_type_and_model(): + """If a model is specified, we check it matches the inferred type.""" + + class BasicThing(lt.Thing): + prop = lt.ThingProperty[bool](model=bool, initial_value=False) + + with pytest.raises(Exception) as e: + + class InvalidThing(lt.Thing): + prop = lt.ThingProperty[bool](model=int, initial_value=False) + + assert exception_is_or_is_caused_by(e.value, MismatchedTypeError) + + with pytest.raises(Exception) as e: + + class InvalidThing(lt.Thing): + prop = lt.ThingProperty(model=bool, initial_value=False) + + assert exception_is_or_is_caused_by(e.value, MissingTypeError) + + +def test_missing_default(): + """Test that a default is required if no model is specified.""" + with pytest.raises(MissingDefaultError): + + class InvalidThing(lt.Thing): + prop = lt.ThingProperty[bool]() + + +def test_annotation_on_class(): + """Test that a type annotation on the attribute is picked up.""" + + class BasicThing(lt.Thing): + prop: bool = lt.ThingProperty(initial_value=False) + + assert isinstance(BasicThing.prop, lt.ThingProperty) + assert BasicThing.prop._value_type is bool + + +def test_overspecified_default(): + """Test that a default is not allowed if a getter is specified.""" + with pytest.raises(ValueError): + + class InvalidThing(lt.Thing): + def get_prop(self) -> bool: + return False + + prop = lt.ThingProperty[bool](initial_value=False, getter=get_prop) def test_instantiation_with_model(): @@ -59,8 +127,10 @@ class MyModel(BaseModel): a: int = 1 b: float = 2.0 - prop = lt.ThingProperty(MyModel, MyModel()) - assert prop.model is MyModel + class BasicThing(lt.Thing): + prop = lt.ThingProperty[MyModel](initial_value=MyModel()) + + assert BasicThing.prop.model is MyModel def test_property_get_and_set(): diff --git a/tests/test_settings.py b/tests/test_settings.py index 50bb7656..c408d4e3 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -11,9 +11,9 @@ class TestThing(lt.Thing): - boolsetting = lt.ThingSetting(bool, False, description="A boolean setting") - stringsetting = lt.ThingSetting(str, "foo", description="A string setting") - dictsetting = lt.ThingSetting( + boolsetting = lt.ThingSetting[bool](bool, False, description="A boolean setting") + stringsetting = lt.ThingSetting[str](str, "foo", description="A string setting") + dictsetting = lt.ThingSetting[dict]( dict, {"a": 1, "b": 2}, description="A dictionary setting" ) diff --git a/tests/test_thing_lifecycle.py b/tests/test_thing_lifecycle.py index 2d7331b9..8e7915f0 100644 --- a/tests/test_thing_lifecycle.py +++ b/tests/test_thing_lifecycle.py @@ -3,7 +3,9 @@ class TestThing(lt.Thing): - alive = lt.ThingProperty(bool, False, description="Is the thing alive?") + alive = lt.ThingProperty[bool]( + initial_value=False, description="Is the thing alive?" + ) def __enter__(self): print("setting up TestThing from __enter__")