From 3118394cb6759761bafe54d68501e901e7c28d21 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 28 Jan 2026 14:47:19 +0100 Subject: [PATCH] Add tests for metadata deserialization and implement it for arithmetic pulse templates --- qupulse/pulses/arithmetic_pulse_template.py | 13 ++++++++---- .../pulses/arithmetic_pulse_template_tests.py | 21 +++++++++++++++---- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/qupulse/pulses/arithmetic_pulse_template.py b/qupulse/pulses/arithmetic_pulse_template.py index 0ac2fce8..e2cc1c00 100644 --- a/qupulse/pulses/arithmetic_pulse_template.py +++ b/qupulse/pulses/arithmetic_pulse_template.py @@ -11,6 +11,7 @@ from qupulse.expressions import ExpressionScalar, ExpressionLike from qupulse.program import ProgramBuilder +from qupulse.pulses.metadata import TemplateMetadata from qupulse.serialization import Serializer, PulseRegistryType from qupulse.parameter_scope import Scope @@ -46,7 +47,9 @@ def __init__(self, silent_atomic: bool = False, measurements: List = None, identifier: str = None, - registry: PulseRegistryType = None): + registry: PulseRegistryType = None, + metadata: TemplateMetadata | dict = None, + ): """Apply an operation (+ or -) channel wise to two atomic pulse templates. Channels only present in one pulse template have the operations neutral element on the other. The operations are defined in `ArithmeticWaveform.operator_map`. @@ -61,7 +64,7 @@ def __init__(self, identifier: See AtomicPulseTemplate registry: See qupulse.serialization.PulseRegistry """ - super().__init__(identifier=identifier, measurements=measurements) + super().__init__(identifier=identifier, measurements=measurements, metadata=metadata) if arithmetic_operator not in ArithmeticWaveform.operator_map: raise ValueError('Unknown operator. allowed: %r' % set(ArithmeticWaveform.operator_map.keys())) @@ -204,7 +207,9 @@ def __init__(self, rhs: Union[PulseTemplate, ExpressionLike, Mapping[ChannelID, ExpressionLike]], *, identifier: Optional[str] = None, - registry: PulseRegistryType = None): + registry: PulseRegistryType = None, + metadata: TemplateMetadata | dict = None, + ): """Implements the arithmetics between an arbitrary pulse template and scalar values. The values can be the same for all channels, channel specific or only for a subset of the inner pulse templates defined channels. The expression may be time-dependent if the pulse template is atomic. @@ -231,7 +236,7 @@ def __init__(self, and a composite pulse template. ValueError: If the scalar is a mapping and contains channels that are not defined on the pulse template. """ - PulseTemplate.__init__(self, identifier=identifier) + PulseTemplate.__init__(self, identifier=identifier, metadata=metadata) if not isinstance(lhs, PulseTemplate) and not isinstance(rhs, PulseTemplate): raise TypeError('At least one of the operands needs to be a pulse template.') diff --git a/tests/pulses/arithmetic_pulse_template_tests.py b/tests/pulses/arithmetic_pulse_template_tests.py index 8c9aac2d..3fe3b530 100644 --- a/tests/pulses/arithmetic_pulse_template_tests.py +++ b/tests/pulses/arithmetic_pulse_template_tests.py @@ -1,3 +1,4 @@ +import typing import unittest from unittest import mock import warnings @@ -206,10 +207,6 @@ def make_kwargs(self): 'measurements': [('m1', 0., .1)] } - def make_instance(self, identifier=None, registry=None): - kwargs = self.make_kwargs() - return self.class_to_test(identifier=identifier, **kwargs, registry=registry) - def assert_equal_instance_except_id(self, lhs: ArithmeticAtomicPulseTemplate, rhs: ArithmeticAtomicPulseTemplate): self.assertIsInstance(lhs, ArithmeticAtomicPulseTemplate) self.assertIsInstance(rhs, ArithmeticAtomicPulseTemplate) @@ -681,3 +678,19 @@ def test_offset(self): _ = self.complex_pt - '4.5' +class ArithmeticPulseTemplateSerializationTest(SerializableTests, unittest.TestCase): + def assert_equal_instance_except_id(self, lhs: ArithmeticPulseTemplate, rhs: ArithmeticPulseTemplate): + self.assertEqual(lhs.lhs, rhs.lhs) + self.assertEqual(lhs.rhs, rhs.rhs) + self.assertEqual(lhs._arithmetic_operator, rhs._arithmetic_operator) + + @property + def class_to_test(self) -> typing.Any: + return ArithmeticPulseTemplate + + def make_kwargs(self) -> dict: + return { + "lhs": 42.2, + "rhs": DummyPulseTemplate(), + "arithmetic_operator": "+", + }