Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions drjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""The DrJAX package."""

import collections as _collections
import functools as _functools
import sys as _sys

Expand All @@ -29,10 +30,30 @@


@_functools.wraps(_api.drjax_program)
def program(*, placements):
def program(**kwargs):
"""A decorator enabling calling the DrJAX API."""
try:
[(placement, value)] = kwargs.items() # pylint: disable=unbalanced-dict-unpacking
except ValueError as e:
raise ValueError(
f'The program API expects a single keyword argument but got {kwargs=}.'
) from e
# We wrap here and send in this module as the one to be modified, as it
# will be the one that users interact with and requires the API changes.
return _api.drjax_program(
placements=placements,
self_module=_sys.modules[__name__],
)
if placement == 'placements':
if not isinstance(value, _collections.abc.Mapping):
# pylint: disable=f-string-without-interpolation
raise ValueError(
f'When using the `placements` argument, value must be a mapping but'
f' got {{type(value)=}}. `placements` is a reserved name and not a'
f' validate placement name.'
)
# pylint: enable=f-string-without-interpolation
return _api.drjax_program(
placements=value, self_module=_sys.modules[__name__]
)
else:
return _api.drjax_program(
placements=kwargs,
self_module=_sys.modules[__name__],
)
Loading