Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions qupulse/pulses/arithmetic_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from qupulse.expressions import ExpressionScalar, ExpressionLike
from qupulse.program import ProgramBuilder
from qupulse.pulses.metadata import TemplateMetadata
from qupulse.serialization import Serializer, PulseRegistryType
from qupulse.parameter_scope import Scope

Expand Down Expand Up @@ -46,7 +47,9 @@ def __init__(self,
silent_atomic: bool = False,
measurements: List = None,
identifier: str = None,
registry: PulseRegistryType = None):
registry: PulseRegistryType = None,
metadata: TemplateMetadata | dict = None,
):
"""Apply an operation (+ or -) channel wise to two atomic pulse templates. Channels only present in one pulse
template have the operations neutral element on the other. The operations are defined in
`ArithmeticWaveform.operator_map`.
Expand All @@ -61,7 +64,7 @@ def __init__(self,
identifier: See AtomicPulseTemplate
registry: See qupulse.serialization.PulseRegistry
"""
super().__init__(identifier=identifier, measurements=measurements)
super().__init__(identifier=identifier, measurements=measurements, metadata=metadata)

if arithmetic_operator not in ArithmeticWaveform.operator_map:
raise ValueError('Unknown operator. allowed: %r' % set(ArithmeticWaveform.operator_map.keys()))
Expand Down Expand Up @@ -204,7 +207,9 @@ def __init__(self,
rhs: Union[PulseTemplate, ExpressionLike, Mapping[ChannelID, ExpressionLike]],
*,
identifier: Optional[str] = None,
registry: PulseRegistryType = None):
registry: PulseRegistryType = None,
metadata: TemplateMetadata | dict = None,
):
"""Implements the arithmetics between an arbitrary pulse template and scalar values. The values can be the same
for all channels, channel specific or only for a subset of the inner pulse templates defined channels.
The expression may be time-dependent if the pulse template is atomic.
Expand All @@ -231,7 +236,7 @@ def __init__(self,
and a composite pulse template.
ValueError: If the scalar is a mapping and contains channels that are not defined on the pulse template.
"""
PulseTemplate.__init__(self, identifier=identifier)
PulseTemplate.__init__(self, identifier=identifier, metadata=metadata)

if not isinstance(lhs, PulseTemplate) and not isinstance(rhs, PulseTemplate):
raise TypeError('At least one of the operands needs to be a pulse template.')
Expand Down
21 changes: 17 additions & 4 deletions tests/pulses/arithmetic_pulse_template_tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
import unittest
from unittest import mock
import warnings
Expand Down Expand Up @@ -206,10 +207,6 @@ def make_kwargs(self):
'measurements': [('m1', 0., .1)]
}

def make_instance(self, identifier=None, registry=None):
kwargs = self.make_kwargs()
return self.class_to_test(identifier=identifier, **kwargs, registry=registry)

def assert_equal_instance_except_id(self, lhs: ArithmeticAtomicPulseTemplate, rhs: ArithmeticAtomicPulseTemplate):
self.assertIsInstance(lhs, ArithmeticAtomicPulseTemplate)
self.assertIsInstance(rhs, ArithmeticAtomicPulseTemplate)
Expand Down Expand Up @@ -681,3 +678,19 @@ def test_offset(self):
_ = self.complex_pt - '4.5'


class ArithmeticPulseTemplateSerializationTest(SerializableTests, unittest.TestCase):
def assert_equal_instance_except_id(self, lhs: ArithmeticPulseTemplate, rhs: ArithmeticPulseTemplate):
self.assertEqual(lhs.lhs, rhs.lhs)
self.assertEqual(lhs.rhs, rhs.rhs)
self.assertEqual(lhs._arithmetic_operator, rhs._arithmetic_operator)

@property
def class_to_test(self) -> typing.Any:
return ArithmeticPulseTemplate

def make_kwargs(self) -> dict:
return {
"lhs": 42.2,
"rhs": DummyPulseTemplate(),
"arithmetic_operator": "+",
}
Loading