diff --git a/drjax/__init__.py b/drjax/__init__.py index 95f1dbf..1cecb25 100644 --- a/drjax/__init__.py +++ b/drjax/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. """The DrJAX package.""" +import collections as _collections import functools as _functools import sys as _sys @@ -29,10 +30,30 @@ @_functools.wraps(_api.drjax_program) -def program(*, placements): +def program(**kwargs): + """A decorator enabling calling the DrJAX API.""" + try: + [(placement, value)] = kwargs.items() # pylint: disable=unbalanced-dict-unpacking + except ValueError as e: + raise ValueError( + f'The program API expects a single keyword argument but got {kwargs=}.' + ) from e # We wrap here and send in this module as the one to be modified, as it # will be the one that users interact with and requires the API changes. - return _api.drjax_program( - placements=placements, - self_module=_sys.modules[__name__], - ) + if placement == 'placements': + if not isinstance(value, _collections.abc.Mapping): + # pylint: disable=f-string-without-interpolation + raise ValueError( + f'When using the `placements` argument, value must be a mapping but' + f' got {{type(value)=}}. `placements` is a reserved name and not a' + f' validate placement name.' + ) + # pylint: enable=f-string-without-interpolation + return _api.drjax_program( + placements=value, self_module=_sys.modules[__name__] + ) + else: + return _api.drjax_program( + placements=kwargs, + self_module=_sys.modules[__name__], + )