diff --git a/qupulse/pulses/abstract_pulse_template.py b/qupulse/pulses/abstract_pulse_template.py index 54e21ae0..e26500d3 100644 --- a/qupulse/pulses/abstract_pulse_template.py +++ b/qupulse/pulses/abstract_pulse_template.py @@ -8,6 +8,7 @@ from qupulse import ChannelID from qupulse.expressions import ExpressionScalar +from qupulse.pulses.metadata import TemplateMetadata from qupulse.serialization import PulseRegistryType from qupulse.pulses.pulse_template import PulseTemplate @@ -81,6 +82,11 @@ def __init__(self, identifier: str, self._register(registry=registry) + @property + def metadata(self) -> TemplateMetadata: + raise NotImplementedError('AbstractPulseTemplate does not support metadata yet. ' + 'Please file an issue if you need this feature.') + def link_to(self, target: PulseTemplate, serialize_linked: bool=None): """Link to another pulse template. diff --git a/qupulse/pulses/constant_pulse_template.py b/qupulse/pulses/constant_pulse_template.py index 586cf197..39749269 100644 --- a/qupulse/pulses/constant_pulse_template.py +++ b/qupulse/pulses/constant_pulse_template.py @@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Union, Mapping, AbstractSet from qupulse.program.waveforms import ConstantWaveform +from qupulse.pulses.metadata import TemplateMetadata from qupulse.utils.types import TimeType, ChannelID from qupulse.utils import cached_property from qupulse.expressions import ExpressionScalar, ExpressionLike @@ -27,7 +28,9 @@ def __init__(self, duration: ExpressionLike, identifier: Optional[str] = None, name: Optional[str] = None, measurements: Optional[List[MeasurementDeclaration]] = None, - registry: PulseRegistryType=None) -> None: + registry: PulseRegistryType=None, + metadata: TemplateMetadata | dict = None + ) -> None: """An atomic pulse template qupulse representing a multi-channel pulse with constant values. As an optimization, this class does not convert plain floats or ints to qupulse expressions. @@ -40,7 +43,7 @@ def __init__(self, duration: ExpressionLike, measurements: Passed to :py:class:`.MeasurementDefiner` superclass registry: The pulse is registered in this mapping after construction if an identifier is provided """ - super().__init__(identifier=identifier, measurements=measurements) + super().__init__(identifier=identifier, measurements=measurements, metadata=metadata) # we special case numeric values in this PulseTemplate for performance reasons self._duration = duration if isinstance(duration, (float, int, TimeType)) else ExpressionScalar(duration) diff --git a/qupulse/pulses/function_pulse_template.py b/qupulse/pulses/function_pulse_template.py index 4a31e578..8796a181 100644 --- a/qupulse/pulses/function_pulse_template.py +++ b/qupulse/pulses/function_pulse_template.py @@ -13,6 +13,7 @@ import sympy from qupulse.expressions import ExpressionScalar +from qupulse.pulses.metadata import TemplateMetadata from qupulse.serialization import Serializer, PulseRegistryType from qupulse.utils.types import ChannelID, TimeType, time_from_float @@ -44,7 +45,9 @@ def __init__(self, *, measurements: Optional[List[MeasurementDeclaration]]=None, parameter_constraints: Optional[List[Union[str, ParameterConstraint]]]=None, - registry: PulseRegistryType=None) -> None: + registry: PulseRegistryType=None, + metadata: TemplateMetadata | dict = None, + ) -> None: """ Args: expression: The function represented by this FunctionPulseTemplate @@ -61,7 +64,7 @@ def __init__(self, :py:class:`~qupulse.pulses.measurement.ParameterConstrainer` superclass registry: After initialization this pulse is registered in the given mapping if an identifier is provided. """ - AtomicPulseTemplate.__init__(self, identifier=identifier, measurements=measurements) + AtomicPulseTemplate.__init__(self, identifier=identifier, measurements=measurements, metadata=metadata) ParameterConstrainer.__init__(self, parameter_constraints=parameter_constraints) self.__expression = ExpressionScalar.make(expression) diff --git a/qupulse/pulses/multi_channel_pulse_template.py b/qupulse/pulses/multi_channel_pulse_template.py index 46c19e0b..ace764cf 100644 --- a/qupulse/pulses/multi_channel_pulse_template.py +++ b/qupulse/pulses/multi_channel_pulse_template.py @@ -16,6 +16,7 @@ import warnings from qupulse.program import ProgramBuilder +from qupulse.pulses.metadata import TemplateMetadata from qupulse.serialization import Serializer, PulseRegistryType from qupulse.parameter_scope import Scope @@ -40,7 +41,9 @@ def __init__(self, parameter_constraints: Optional[List] = None, measurements: Optional[List[MeasurementDeclaration]] = None, registry: PulseRegistryType = None, - duration: Optional[ExpressionLike] = None) -> None: + duration: Optional[ExpressionLike] = None, + metadata: TemplateMetadata | dict = None, + ) -> None: """Combines multiple AtomicPulseTemplates of the same duration that are defined on different channels into an AtomicPulseTemplate. If the duration keyword argument is given it is enforced that the instantiated pulse template has this duration. @@ -55,7 +58,7 @@ def __init__(self, duration: Enforced duration of the pulse template on instantiation. build_waveform checks all sub-waveforms have this duration. If True the equality of durations is only checked durtin instantiation not construction. """ - AtomicPulseTemplate.__init__(self, identifier=identifier, measurements=measurements) + AtomicPulseTemplate.__init__(self, identifier=identifier, measurements=measurements, metadata=metadata) ParameterConstrainer.__init__(self, parameter_constraints=parameter_constraints) self._subtemplates = [st if isinstance(st, PulseTemplate) else MappingPulseTemplate.from_tuple(st) for st in @@ -227,7 +230,9 @@ def __init__(self, template: PulseTemplate, overwritten_channels: Mapping[ChannelID, Union[ExpressionScalar, Sympifyable]], *, identifier: Optional[str]=None, - registry: Optional[PulseRegistryType] = None): + registry: Optional[PulseRegistryType] = None, + metadata: TemplateMetadata | dict = None, + ): """Pulse template to add new or overwrite existing channels of a contained pulse template. The channel values may be time dependent if the contained pulse template is atomic. @@ -238,7 +243,7 @@ def __init__(self, identifier: Name of the pulse template for serialization registry: Pulse template gets registered here if not None. """ - super().__init__(identifier=identifier) + super().__init__(identifier=identifier, metadata=metadata) self._template = template self._overwritten_channels = {channel: ExpressionScalar(value) diff --git a/qupulse/pulses/point_pulse_template.py b/qupulse/pulses/point_pulse_template.py index 6d0eaf95..c25c51f3 100644 --- a/qupulse/pulses/point_pulse_template.py +++ b/qupulse/pulses/point_pulse_template.py @@ -10,6 +10,7 @@ import sympy import numpy as np +from qupulse.pulses.metadata import TemplateMetadata from qupulse.utils.sympy import IndexedBroadcast from qupulse.utils.types import ChannelID from qupulse.expressions import Expression, ExpressionScalar @@ -50,9 +51,11 @@ def __init__(self, parameter_constraints: Optional[List[Union[str, ParameterConstraint]]]=None, measurements: Optional[List[MeasurementDeclaration]]=None, identifier: Optional[str]=None, - registry: PulseRegistryType=None) -> None: + registry: PulseRegistryType=None, + metadata: TemplateMetadata | dict = None, + ) -> None: - AtomicPulseTemplate.__init__(self, identifier=identifier, measurements=measurements) + AtomicPulseTemplate.__init__(self, identifier=identifier, measurements=measurements, metadata=metadata) ParameterConstrainer.__init__(self, parameter_constraints=parameter_constraints) self._channels = tuple(channel_names) diff --git a/qupulse/pulses/pulse_template.py b/qupulse/pulses/pulse_template.py index 9c2b0bc3..5fff90b5 100644 --- a/qupulse/pulses/pulse_template.py +++ b/qupulse/pulses/pulse_template.py @@ -681,8 +681,9 @@ class AtomicPulseTemplate(PulseTemplate, MeasurementDefiner): def __init__(self, *, identifier: Optional[str], - measurements: Optional[List[MeasurementDeclaration]]): - PulseTemplate.__init__(self, identifier=identifier) + measurements: Optional[List[MeasurementDeclaration]], + metadata: TemplateMetadata | dict = None): + PulseTemplate.__init__(self, identifier=identifier, metadata=metadata) MeasurementDefiner.__init__(self, measurements=measurements) def with_parallel_atomic(self, *parallel: 'AtomicPulseTemplate') -> 'AtomicPulseTemplate': diff --git a/qupulse/pulses/table_pulse_template.py b/qupulse/pulses/table_pulse_template.py index 60ac59cf..785cad72 100644 --- a/qupulse/pulses/table_pulse_template.py +++ b/qupulse/pulses/table_pulse_template.py @@ -20,6 +20,7 @@ import sympy from sympy.logic.boolalg import BooleanAtom +from qupulse.pulses.metadata import TemplateMetadata from qupulse.utils import pairwise from qupulse.utils.types import ChannelID from qupulse.serialization import Serializer, PulseRegistryType @@ -153,7 +154,9 @@ def __init__(self, entries: Dict[ChannelID, Sequence[EntryInInit]], parameter_constraints: Optional[List[Union[str, ParameterConstraint]]]=None, measurements: Optional[List[MeasurementDeclaration]]=None, consistency_check: bool=True, - registry: PulseRegistryType=None) -> None: + registry: PulseRegistryType=None, + metadata: TemplateMetadata | dict = None, + ) -> None: """ Construct a `TablePulseTemplate` from a dict which maps channels to their entries. By default the consistency of the provided entries is checked. There are two static functions for convenience construction: from_array and @@ -167,7 +170,7 @@ def __init__(self, entries: Dict[ChannelID, Sequence[EntryInInit]], measurements: Measurement declaration list that is forwarded to the MeasurementDefiner superclass consistency_check: If True the consistency of the times will be checked on construction as far as possible """ - AtomicPulseTemplate.__init__(self, identifier=identifier, measurements=measurements) + AtomicPulseTemplate.__init__(self, identifier=identifier, measurements=measurements, metadata=metadata) ParameterConstrainer.__init__(self, parameter_constraints=parameter_constraints) if not entries: diff --git a/qupulse/pulses/time_reversal_pulse_template.py b/qupulse/pulses/time_reversal_pulse_template.py index e0b9376c..ebaca7f2 100644 --- a/qupulse/pulses/time_reversal_pulse_template.py +++ b/qupulse/pulses/time_reversal_pulse_template.py @@ -7,6 +7,7 @@ from qupulse import ChannelID from qupulse.program import ProgramBuilder from qupulse.program.waveforms import Waveform +from qupulse.pulses.metadata import TemplateMetadata from qupulse.serialization import PulseRegistryType from qupulse.expressions import ExpressionScalar @@ -18,8 +19,10 @@ class TimeReversalPulseTemplate(PulseTemplate): def __init__(self, inner: PulseTemplate, identifier: Optional[str] = None, - registry: PulseRegistryType = None): - super(TimeReversalPulseTemplate, self).__init__(identifier=identifier) + registry: PulseRegistryType = None, + metadata: TemplateMetadata | dict = None, + ): + super(TimeReversalPulseTemplate, self).__init__(identifier=identifier, metadata=metadata) self._inner = inner self._register(registry=registry) diff --git a/tests/pulses/sequencing_dummies.py b/tests/pulses/sequencing_dummies.py index 5518c023..4e7ed09a 100644 --- a/tests/pulses/sequencing_dummies.py +++ b/tests/pulses/sequencing_dummies.py @@ -156,8 +156,9 @@ def __init__(self, final_values: Dict[ChannelID, Any]=None, program: Optional[Loop]=None, identifier=None, + metadata=None, registry=None) -> None: - super().__init__(identifier=identifier, measurements=measurements) + super().__init__(identifier=identifier, measurements=measurements, metadata=metadata) self.requires_stop_ = requires_stop self.requires_stop_arguments = [] diff --git a/tests/serialization_tests.py b/tests/serialization_tests.py index a49ae7e6..e980c096 100644 --- a/tests/serialization_tests.py +++ b/tests/serialization_tests.py @@ -25,9 +25,16 @@ class DummySerializable(Serializable): - def __init__(self, identifier: Optional[str]=None, registry: PulseRegistryType=None, **kwargs) -> None: + def __init__(self, + identifier: Optional[str]=None, + registry: PulseRegistryType=None, + **kwargs) -> None: super().__init__(identifier) for name in kwargs: + if name == "metadata": + # this was the easiest way to adjust this test + # metadata is not a Serializable but a PulseTemplate attribute so it is a hacky solution + continue setattr(self, name, kwargs[name]) self._register(registry=registry) @@ -79,13 +86,15 @@ def assert_equal_instance(self, lhs, rhs): def assert_equal_instance_except_id(self, lhs, rhs): pass - def make_instance(self, identifier=None, registry=None): - return self.class_to_test(identifier=identifier, registry=registry, **self.make_kwargs()) + def make_instance(self, identifier=None, registry=None, metadata=None): + return self.class_to_test(identifier=identifier, registry=registry, metadata=metadata, **self.make_kwargs()) - def make_serialization_data(self, identifier=None): + def make_serialization_data(self, identifier=None, metadata=None): data = {Serializable.type_identifier_name: self.class_to_test.get_type_identifier(), **self.make_kwargs()} if identifier: data[Serializable.identifier_name] = identifier + if metadata: + data["metadata"] = metadata return data def test_identifier(self) -> None: