Skip to content

Commit 75bf89d

Browse files
Experimental fix for numpy typing. Works only for returns
1 parent d9c68e3 commit 75bf89d

File tree

5 files changed

+43
-9
lines changed

5 files changed

+43
-9
lines changed

src/labthings_fastapi/client/__init__.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import httpx
1414
from urllib.parse import urlparse, urljoin
1515

16+
import numpy as np
1617
from pydantic import BaseModel
1718

1819
from .outputs import ClientBlobOutput
@@ -159,7 +160,9 @@ def set_property(self, path: str, value: Any) -> None:
159160
r = self.client.put(urljoin(self.path, path), json=value)
160161
r.raise_for_status()
161162

162-
def invoke_action(self, path: str, **kwargs: Any) -> Any:
163+
def invoke_action(
164+
self, path: str, labthings_typehint: str | None, **kwargs: Any
165+
) -> Any:
163166
r"""Invoke an action on the Thing.
164167
165168
This method will make the initial POST request to invoke an action,
@@ -205,7 +208,7 @@ def invoke_action(self, path: str, **kwargs: Any) -> Any:
205208
href=invocation["output"]["href"],
206209
client=self.client,
207210
)
208-
return invocation["output"]
211+
return _adjust_type(invocation["output"], labthings_typehint)
209212
else:
210213
raise RuntimeError(f"Action did not complete successfully: {invocation}")
211214

@@ -276,6 +279,15 @@ class Client(cls): # type: ignore[valid-type, misc]
276279
return Client
277280

278281

282+
def _adjust_type(value: Any, labthings_typehint: str | None) -> Any:
283+
"""Adjust the return type based on labthings_typehint."""
284+
if labthings_typehint is None:
285+
return value
286+
if labthings_typehint == "ndarray":
287+
return np.array(value)
288+
raise ValueError(f"No type of {labthings_typehint} known")
289+
290+
279291
class PropertyClientDescriptor:
280292
"""A base class for properties on `.ThingClient` objects."""
281293

@@ -361,9 +373,12 @@ def add_action(cls: type[ThingClient], action_name: str, action: dict) -> None:
361373
:param action: a dictionary representing the action, in :ref:`wot_td`
362374
format.
363375
"""
376+
labthings_typehint = action["output"].get("format", None)
364377

365378
def action_method(self: ThingClient, **kwargs: Any) -> Any:
366-
return self.invoke_action(action_name, **kwargs)
379+
return self.invoke_action(
380+
action_name, labthings_typehint=labthings_typehint, **kwargs
381+
)
367382

368383
if "output" in action and "type" in action["output"]:
369384
action_method.__annotations__["return"] = action["output"]["type"]

src/labthings_fastapi/thing_description/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,9 +300,12 @@ def type_to_dataschema(t: type, **kwargs: Any) -> DataSchema:
300300
:raise ValidationError: if the datatype cannot be represented
301301
by a `.DataSchema`.
302302
"""
303+
data_format = None
303304
if hasattr(t, "model_json_schema"):
304305
# The input should be a `BaseModel` subclass, in which case this works:
305306
json_schema = t.model_json_schema()
307+
if "_labthings_typehint" in t.__private_attributes__:
308+
data_format = t.__private_attributes__["_labthings_typehint"].default
306309
else:
307310
# In principle, the below should work for any type, though some
308311
# deferred annotations can go wrong.
@@ -319,6 +322,8 @@ def type_to_dataschema(t: type, **kwargs: Any) -> DataSchema:
319322
if k in schema_dict:
320323
del schema_dict[k]
321324
schema_dict.update(kwargs)
325+
if data_format is not None:
326+
schema_dict["format"] = data_format
322327
try:
323328
return DataSchema(**schema_dict)
324329
except ValidationError as ve:

src/labthings_fastapi/types/numpy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,15 @@ class DenumpifyingDict(RootModel):
147147

148148
root: Annotated[Mapping, WrapSerializer(denumpify_serializer)]
149149
model_config = ConfigDict(arbitrary_types_allowed=True)
150+
151+
152+
class ArrayModel(RootModel):
153+
"""A model automatically used by actions as the return type for a numpy array.
154+
155+
This models is passed to FastAPI as the return model for any action that returns
156+
a numpy array. The private typehint is saved as format information to allow
157+
a ThingClient to reconstruct the array from the list sent over HTTP.
158+
"""
159+
160+
root: NDArray
161+
_labthings_typehint: str = "ndarray"

src/labthings_fastapi/utilities/introspection.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
from pydantic import BaseModel, ConfigDict, Field, RootModel
1717
from pydantic.main import create_model
1818
from fastapi.dependencies.utils import analyze_param, get_typed_signature
19+
import numpy as np
20+
21+
from ..types.numpy import ArrayModel
1922

2023

2124
class EmptyObject(BaseModel):
@@ -178,6 +181,9 @@ def return_type(func: Callable) -> Type:
178181
else:
179182
# We use `get_type_hints` rather than just `sig.return_annotation`
180183
# because it resolves forward references, etc.
184+
rtype = get_type_hints(func)["return"]
185+
if isinstance(rtype, type) and issubclass(rtype, np.ndarray):
186+
return ArrayModel
181187
type_hints = get_type_hints(func, include_extras=True)
182188
return type_hints["return"]
183189

tests/test_numpy_type.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
11
from __future__ import annotations
22

3-
from pydantic import BaseModel, RootModel
3+
from pydantic import BaseModel
44
import numpy as np
55
from fastapi.testclient import TestClient
66

77
from labthings_fastapi.testing import create_thing_without_server
8-
from labthings_fastapi.types.numpy import NDArray, DenumpifyingDict
8+
from labthings_fastapi.types.numpy import NDArray, DenumpifyingDict, ArrayModel
99
import labthings_fastapi as lt
1010

1111

12-
class ArrayModel(RootModel):
13-
root: NDArray
14-
15-
1612
def check_field_works_with_list(data):
1713
class Model(BaseModel):
1814
a: NDArray

0 commit comments

Comments
 (0)