diff --git a/src/qumada/measurement/measurement.py b/src/qumada/measurement/measurement.py index 293a4ec6..496ccae6 100644 --- a/src/qumada/measurement/measurement.py +++ b/src/qumada/measurement/measurement.py @@ -22,6 +22,8 @@ from __future__ import annotations import copy +import functools +import importlib import inspect import json import logging @@ -42,6 +44,7 @@ from qumada.instrument.buffers import is_bufferable, is_triggerable from qumada.metadata import Metadata +from qumada.utils.liveplot import MeasurementAndPlot from qumada.utils.ramp_parameter import ramp_or_set_parameter from qumada.utils.utils import flatten_array @@ -103,6 +106,7 @@ class MeasurementScript(ABC): """ PARAMETER_NAMES: set[str] = load_param_whitelist() + DEFAULT_LIVE_PLOTTER: callable = None def __init__(self): # Create function hooks for metadata @@ -111,10 +115,27 @@ def __init__(self): self.run = create_hook(self.run, self._add_data_to_metadata) self.run = create_hook(self.run, self._add_current_datetime_to_metadata) + self.live_plotter = self.DEFAULT_LIVE_PLOTTER + self.properties: dict[Any, Any] = {} self.gate_parameters: dict[Any, dict[Any, Parameter | None] | Parameter | None] = {} self._buffered_num_points: int | None = None + def _new_measurement(self, name) -> MeasurementAndPlot: + return MeasurementAndPlot(name=name, gui=self.live_plotter) + + def _dond(self, *args, **kwargs): + """This is a wrapper around qcodes dond function that monkeypatches the live plotter in the datasaver""" + # we need to use importlib here because the dond function shadows the qcodes.dataset.dond package + do_nd = importlib.import_module("qcodes.dataset.dond.do_nd") + + prev_meas_cls = do_nd.Measurement + try: + do_nd.Measurement = functools.partial(MeasurementAndPlot, gui=self.live_plotter) + return do_nd.dond(*args, **kwargs) + finally: + do_nd.Measurement = prev_meas_cls + def add_gate_parameter(self, parameter_name: str, gate_name: str = None, parameter: Parameter = None) -> None: """ Adds a gate parameter to self.gate_parameters. diff --git a/src/qumada/measurement/scripts/generic_measurement.py b/src/qumada/measurement/scripts/generic_measurement.py index ecd665c3..0682315f 100644 --- a/src/qumada/measurement/scripts/generic_measurement.py +++ b/src/qumada/measurement/scripts/generic_measurement.py @@ -26,7 +26,6 @@ import numpy as np from qcodes.dataset import dond -from qcodes.dataset.measurements import Measurement from qcodes.parameters.specialized_parameters import ElapsedTimeParameter from qumada.instrument.buffers import is_bufferable @@ -96,7 +95,7 @@ def run(self, **dond_kwargs) -> list: self.initialize(inactive_dyn_channels=inactive_channels) sleep(wait_time) data.append( - dond( + self._dond( sweep, *measured_channels, measurement_name=self._measurement_name, @@ -154,7 +153,7 @@ def run(self, **dond_kwargs): for sweep in self.dynamic_sweeps: ramp_or_set_parameter(sweep._param, sweep.get_setpoints()[0]) sleep(wait_time) - data = dond( + data = self._dond( *tuple(self.dynamic_sweeps), *tuple(self.gettable_channels), measurement_name=measurement_name, @@ -248,7 +247,7 @@ def run(self): timestep = self.settings.get("timestep", 1) timer = ElapsedTimeParameter("time") naming_helper(self, default_name="Timetrace") - meas = Measurement(name=self.measurement_name) + meas = self._new_measurement(name=self.measurement_name) meas.register_parameter(timer) for parameter in [*self.gettable_channels, *self.dynamic_channels]: meas.register_parameter( @@ -331,7 +330,7 @@ def run(self): self.generate_lists() naming_helper(self, default_name="Timetrace") - meas = Measurement(name=self.measurement_name) + meas = self._new_measurement(name=self.measurement_name) meas.register_parameter(timer) for parameter in [*self.gettable_channels, *self.dynamic_channels]: @@ -443,7 +442,7 @@ def run(self): timestep = self.settings.get("timestep", 1) # backsweeps = self.settings.get("backsweeps", False) timer = ElapsedTimeParameter("time") - meas = Measurement(name=self.metadata.measurement.name or "timetrace") + meas = self._new_measurement(name=self.metadata.measurement.name or "timetrace") meas.register_parameter(timer) setpoints = [timer] for parameter in self.dynamic_channels: @@ -526,7 +525,7 @@ def run(self): datasets = [] self.generate_lists() naming_helper(self, default_name="Timetrace with sweeps") - meas = Measurement(name=self.measurement_name) + meas = self._new_measurement(name=self.measurement_name) meas.register_parameter(timer) for dynamic_param in self.dynamic_parameters: @@ -685,7 +684,7 @@ def run(self): dynamic_param = self.dynamic_sweeps[i].param inactive_channels = [chan for chan in self.dynamic_channels if chan != dynamic_param] self.initialize(inactive_dyn_channels=inactive_channels) - meas = Measurement(name=self.measurement_name) + meas = self._new_measurement(name=self.measurement_name) meas.register_parameter(dynamic_param) for c_param in self.active_compensating_channels: meas.register_parameter( @@ -890,7 +889,7 @@ def run(self): self.measurement_name += f" {dynamic_parameter['gate']}" self.properties[dynamic_parameter["gate"]][dynamic_parameter["parameter"]]["_is_triggered"] = True dynamic_param = dynamic_sweep.param - meas = Measurement(name=self.measurement_name) + meas = self._new_measurement(name=self.measurement_name) meas.register_parameter(dynamic_param) # This next block is required to log static and idle dynamic # parameters that cannot be buffered. @@ -1083,7 +1082,7 @@ def run(self): gate_names = [gate["gate"] for gate in self.dynamic_parameters] self.measurement_name += f" {gate_names}" - meas = Measurement(name=self.measurement_name) + meas = self._new_measurement(name=self.measurement_name) if reverse_param_order: slow_param = self.dynamic_parameters[1] @@ -1339,7 +1338,7 @@ def run(self): gate_names = [gate["gate"] for gate in self.dynamic_parameters] self.measurement_name += f" {gate_names}" - meas = Measurement(name=self.measurement_name) + meas = self._new_measurement(name=self.measurement_name) meas.register_parameter(timer) for parameter in self.dynamic_parameters: self.properties[parameter["gate"]][parameter["parameter"]]["_is_triggered"] = True @@ -1542,7 +1541,7 @@ def run(self): gate_names = [gate["gate"] for gate in self.dynamic_parameters] self.measurement_name += f" {gate_names}" - meas = Measurement(name=self.measurement_name) + meas = self._new_measurement(name=self.measurement_name) meas.register_parameter(timer) for parameter in self.dynamic_parameters: self.properties[parameter["gate"]][parameter["parameter"]]["_is_triggered"] = True diff --git a/src/qumada/utils/liveplot.py b/src/qumada/utils/liveplot.py new file mode 100644 index 00000000..8a27312b --- /dev/null +++ b/src/qumada/utils/liveplot.py @@ -0,0 +1,49 @@ +import contextlib +import functools +from collections.abc import Sequence +from typing import Optional, Protocol, Union + +from qcodes import Measurement +from qcodes.dataset.data_set import DataSet +from qcodes.parameters import ParameterBase + + +class MeasurementAndPlot: + def __init__(self, *, name: str, gui=None, **kwargs): + self.qcodes_measurement = Measurement(name=name, **kwargs) + self.gui = gui + + def register_parameter( + self, parameter: ParameterBase, setpoints: Optional[Sequence[Union[str, ParameterBase]]] = None, **kwargs + ): + self.qcodes_measurement.register_parameter(parameter, setpoints, **kwargs) + + def set_shapes(self, shapes): + self.qcodes_measurement.set_shapes(shapes=shapes) + + @contextlib.contextmanager + def run(self, **kwargs): + if self.gui is not None: + # here we could add some more arguments in the future + plot_target = self.gui + else: + plot_target = None + + with self.qcodes_measurement.run(**kwargs) as qcodes_datasaver: + yield DataSaverAndPlotter(self, qcodes_datasaver, plot_target) + + +class DataSaverAndPlotter: + def __init__(self, parent: MeasurementAndPlot, qcodes_datasaver, plot_target: callable): + self._parent = parent + self.qcodes_datasaver = qcodes_datasaver + self.plot_target = plot_target + + def add_result(self, *args): + self.qcodes_datasaver.add_result(*args) + if self.plot_target is not None: + self.plot_target(self.dataset.to_xarray_dataset()) + + @property + def dataset(self) -> DataSet: + return self.qcodes_datasaver.dataset diff --git a/src/tests/conftest.py b/src/tests/conftest.py new file mode 100644 index 00000000..c3b52717 --- /dev/null +++ b/src/tests/conftest.py @@ -0,0 +1,55 @@ +import dataclasses +import pathlib +import tempfile +import threading +import time + +import pytest +from qcodes.dataset.experiment_container import load_or_create_experiment +from qcodes.station import Station + +from qumada.instrument.buffered_instruments import BufferedDummyDMM as DummyDmm +from qumada.instrument.custom_drivers.Dummies.dummy_dac import DummyDac +from qumada.instrument.mapping import ( + DUMMY_DMM_MAPPING, + add_mapping_to_instrument, +) +from qumada.instrument.mapping.Dummies.DummyDac import DummyDacMapping +from qumada.utils.load_from_sqlite_db import load_db + + +@dataclasses.dataclass +class MeasurementTestSetup: + trigger: threading.Event + + station: Station + dmm: DummyDmm + dac: DummyDac + + db_path: pathlib.Path + + +@pytest.fixture +def measurement_test_setup(tmp_path): + trigger = threading.Event() + + # Setup qcodes station + station = Station() + + # The dummy instruments have a trigger_event attribute as replacement for + # the trigger inputs of real instruments. + + dmm = DummyDmm("dmm", trigger_event=trigger) + add_mapping_to_instrument(dmm, mapping=DUMMY_DMM_MAPPING) + station.add_component(dmm) + + dac = DummyDac("dac", trigger_event=trigger) + add_mapping_to_instrument(dac, mapping=DummyDacMapping()) + station.add_component(dac) + + db_path = tmp_path / "test.db" + load_db(str(db_path)) + load_or_create_experiment("test", "dummy_sample") + + yield MeasurementTestSetup(trigger, station, dmm, dac, db_path) + station.close_all_registered_instruments() diff --git a/src/tests/device_test.py b/src/tests/device_test.py new file mode 100644 index 00000000..67022fc1 --- /dev/null +++ b/src/tests/device_test.py @@ -0,0 +1,132 @@ +import dataclasses +import itertools + +import numpy as np +import pytest + +from qumada.measurement.device_object import QumadaDevice + +from .conftest import MeasurementTestSetup + + +@dataclasses.dataclass +class DeviceTestSetup: + measurement_test_setup: MeasurementTestSetup + device: QumadaDevice + parameters: dict + namespace: dict + + +@pytest.fixture +def device_test_setup(measurement_test_setup): + """This fixture is derived from device_object_example""" + + parameters = { + "ohmic": { + "current": {"type": "gettable"}, + }, + "gate1": {"voltage": {"type": "static"}}, + "gate2": {"voltage": {"type": "static"}}, + } + namespace = {} + device = QumadaDevice.create_from_dict(parameters, station=measurement_test_setup.station, namespace=namespace) + + buffer_settings = { + "sampling_rate": 512, + "num_points": 12, + "delay": 0, + } + + mapping = { + "ohmic": { + "current": measurement_test_setup.dmm.current, + }, + "gate1": { + "voltage": measurement_test_setup.dac.ch01.voltage, + }, + "gate2": { + "voltage": measurement_test_setup.dac.ch02.voltage, + }, + } + + # This tells a measurement script how to start a buffered measurement. + # "Hardware" means that you want to use a hardware trigger. To start a measurement, + # the method provided as "trigger_start" is called. The "trigger_reset" method is called + # at the end of each buffered line, in our case resetting the trigger flag. + # For real instruments, you might have to define a method that sets the output of your instrument + # to a desired value as "trigger_start". For details on other ways to setup your triggers, + # check the documentation. + + buffer_script_settings = { + "trigger_type": "hardware", + "trigger_start": measurement_test_setup.trigger.set, + "trigger_reset": measurement_test_setup.trigger.clear, + } + + device.buffer_script_setup = buffer_script_settings + device.buffer_settings = buffer_settings + + # device.mapping() + # - map_terminals_gui(self.station.components, self.instrument_parameters, instrument_parameters) + device.instrument_parameters = mapping + # - self.update_terminal_parameters() + device.update_terminal_parameters() + + # map_triggers(station.components) ??? + measurement_test_setup.dac._qumada_mapping.trigger_in = None + (measurement_test_setup.dmm._qumada_buffer.trigger,) = measurement_test_setup.dmm._qumada_buffer.AVAILABLE_TRIGGERS + + return DeviceTestSetup( + measurement_test_setup, + device, + parameters, + namespace, + ) + + +@pytest.mark.parametrize( + "buffered,backsweep", + itertools.product( + # buffered + [True, False], + # backsweep + [False, True], + ), +) +def test_measured_ramp(device_test_setup, buffered, backsweep): + gate1 = device_test_setup.namespace["gate1"] + + plot_args = [] + + def plot_backend(*args, **kwargs): + plot_args.append((args, kwargs)) + + from qumada.measurement.measurement import MeasurementScript + + MeasurementScript.DEFAULT_LIVE_PLOTTER = plot_backend + + (qcodes_data,) = gate1.voltage.measured_ramp(0.4, start=-0.3, buffered=buffered, backsweep=backsweep) + if backsweep: + assert gate1.voltage() == pytest.approx(-0.3, abs=0.001) + else: + assert gate1.voltage() == pytest.approx(0.4, abs=0.001) + + if not buffered: + # TODO: Why is this necessary??? + (qcodes_data, _, _) = qcodes_data + xarr = qcodes_data.to_xarray_dataset() + + set_points = xarr.dac_ch01_voltage.values + + if backsweep: + fwd = np.linspace(-0.3, 0.4, len(set_points) // 2) + expected = np.concatenate((fwd, fwd[::-1])) + else: + expected = np.linspace(-0.3, 0.4, len(set_points)) + + if buffered: + assert len(plot_args) == 1 + int(backsweep) + else: + assert len(plot_args) == len(set_points) + + np.testing.assert_almost_equal(expected, set_points) diff --git a/src/tests/mapping_test.py b/src/tests/mapping_test.py index 2db714d5..14274ff8 100644 --- a/src/tests/mapping_test.py +++ b/src/tests/mapping_test.py @@ -52,34 +52,38 @@ from qumada.measurement.scripts.generic_measurement import Generic_1D_Sweep -@pytest.fixture(name="dmm", scope="session") +@pytest.fixture(name="dmm") def fixture_dmm(): dmm = DummyDmm("dmm") add_mapping_to_instrument(dmm, mapping=mapping.DUMMY_DMM_MAPPING) - return dmm + yield dmm + dmm.close() -@pytest.fixture(name="dac", scope="session") +@pytest.fixture(name="dac") def fixture_dac(): dac = DummyDac("dac") add_mapping_to_instrument(dac, mapping=DummyDacMapping()) - return dac + yield dac + dac.close() -@pytest.fixture(name="dci", scope="session") +@pytest.fixture(name="dci") def fixture_dci(): dci = DummyChannelInstrument("dci") add_mapping_to_instrument(dci, mapping=mapping.DUMMY_CHANNEL_MAPPING) - return dci + yield dci + dci.close() -@pytest.fixture(name="station_with_instruments", scope="session") +@pytest.fixture(name="station_with_instruments") def fixture_station_with_instruments(dmm, dac, dci): station = Station() station.add_component(dmm) station.add_component(dac) station.add_component(dci) - return station + yield station + station.close_all_registered_instruments() @pytest.fixture(name="script") diff --git a/src/tests/measurement_test.py b/src/tests/measurement_test.py new file mode 100644 index 00000000..2b106423 --- /dev/null +++ b/src/tests/measurement_test.py @@ -0,0 +1,102 @@ +import dataclasses +import tempfile +import threading + +import numpy as np +import pytest +import yaml +from qcodes.dataset import ( + Measurement, + experiments, + initialise_or_create_database_at, + load_by_run_spec, + load_or_create_experiment, +) +from qcodes.station import Station + +from qumada.instrument.buffered_instruments import BufferedDummyDMM as DummyDmm +from qumada.instrument.buffers.buffer import ( + load_trigger_mapping, + map_triggers, + save_trigger_mapping, +) +from qumada.instrument.custom_drivers.Dummies.dummy_dac import DummyDac +from qumada.instrument.mapping import ( + DUMMY_DMM_MAPPING, + add_mapping_to_instrument, + map_terminals_gui, +) +from qumada.instrument.mapping.Dummies.DummyDac import DummyDacMapping +from qumada.measurement.scripts import ( + Generic_1D_parallel_asymm_Sweep, + Generic_1D_parallel_Sweep, + Generic_1D_Sweep, + Generic_1D_Sweep_buffered, + Generic_2D_Sweep_buffered, + Generic_nD_Sweep, + Timetrace, +) +from qumada.utils.generate_sweeps import generate_sweep, replace_parameter_settings +from qumada.utils.GUI import open_web_gui +from qumada.utils.load_from_sqlite_db import load_db +from qumada.utils.ramp_parameter import * + + +@pytest.fixture +def buffer_settings(): + return { + "sampling_rate": 512, + "duration": 12 / 512, + "burst_duration": 12 / 512, + "delay": 0, + } + + +@pytest.fixture +def parameters(): + return { + "ohmic": { + "voltage": {"type": "gettable"}, + "current": {"type": "gettable"}, + }, + "gate1": {"voltage": {"type": "dynamic", "setpoints": np.linspace(0, np.pi, 12), "value": 0}}, + "gate2": {"voltage": {"type": "dynamic", "setpoints": np.linspace(0, np.pi, 12), "value": 0}}, + } + + +def test_1d_buffered(measurement_test_setup, buffer_settings, parameters): + script = Generic_1D_Sweep_buffered() + script.setup( + parameters, + metadata=None, + buffer_settings=buffer_settings, + trigger_type="hardware", + trigger_start=measurement_test_setup.trigger.set, + trigger_reset=measurement_test_setup.trigger.clear, + ) + + mapping = { + "ohmic": { + "voltage": measurement_test_setup.dmm.voltage, + "current": measurement_test_setup.dmm.current, + }, + "gate1": { + "voltage": measurement_test_setup.dac.ch01.voltage, + }, + "gate2": { + "voltage": measurement_test_setup.dac.ch02.voltage, + }, + } + script.gate_parameters = mapping + ds1, ds2 = script.run() + ds1 = ds1.to_xarray_dataset() + ds2 = ds2.to_xarray_dataset() + + np.testing.assert_almost_equal( + parameters["gate1"]["voltage"]["setpoints"], + ds1.dac_ch01_voltage.values, + ) + np.testing.assert_almost_equal( + parameters["gate2"]["voltage"]["setpoints"], + ds2.dac_ch02_voltage.values, + )