From 5191e2fbed280a27bb4b396f4a831f6ef3383eca Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 20 Jun 2025 17:16:32 +0100 Subject: [PATCH 1/5] feat!: Remove non-cli logic from cli.run_plan Argument parsing has been moved into a click option validator, and task creation has been moved into the client create_task methods. This is a breaking change in the BlueapiClient as task methods now accept plan name and parameters as separate arguments. --- src/blueapi/cli/cli.py | 43 ++++++++++++++------------ src/blueapi/client/client.py | 30 ++++++++++++------ tests/unit_tests/client/test_client.py | 34 ++++++++++---------- tests/unit_tests/test_cli.py | 12 +++---- 4 files changed, 69 insertions(+), 50 deletions(-) diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 3f391df9a..7e012dca4 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -6,19 +6,21 @@ from functools import wraps from pathlib import Path from pprint import pprint +from typing import Any import click from bluesky.callbacks.best_effort import BestEffortCallback from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker +from click.core import Context, Parameter from click.exceptions import ClickException +from click.types import ParamType from observability_utils.tracing import setup_tracing -from pydantic import ValidationError from requests.exceptions import ConnectionError from blueapi import __version__, config from blueapi.cli.format import OutputFormat -from blueapi.client.client import BlueapiClient +from blueapi.client.client import BlueapiClient, TaskParameters from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient from blueapi.client.rest import ( BlueskyRemoteControlError, @@ -34,12 +36,27 @@ from blueapi.log import set_up_logging from blueapi.service.authentication import SessionCacheManager, SessionManager from blueapi.service.model import SourceInfo -from blueapi.worker import ProgressEvent, Task, WorkerEvent +from blueapi.worker import ProgressEvent, WorkerEvent from .scratch import setup_scratch from .updates import CliEventRenderer +class ParametersType(ParamType): + name = "TaskParameters" + + def convert( + self, value: Any, param: Parameter | None, ctx: Context | None + ) -> TaskParameters: + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError as jde: + self.fail(f"Parameters are not valid JSON: {jde}") + else: + return super().convert(value, param, ctx) + + @click.group( invoke_without_command=True, context_settings={"auto_envvar_prefix": "BLUEAPI"} ) @@ -220,7 +237,7 @@ def on_event( @controller.command(name="run") @click.argument("name", type=str) -@click.argument("parameters", type=str, required=False) +@click.argument("parameters", type=ParametersType(), default={}, required=False) @click.option( "--foreground/--background", "--fg/--bg", type=bool, is_flag=True, default=True ) @@ -236,25 +253,13 @@ def on_event( def run_plan( obj: dict, name: str, - parameters: str | None, timeout: float | None, foreground: bool, + parameters: TaskParameters, ) -> None: """Run a plan with parameters""" client: BlueapiClient = obj["client"] - parameters = parameters or "{}" - try: - parsed_params = json.loads(parameters) if isinstance(parameters, str) else {} - except json.JSONDecodeError as jde: - raise ClickException(f"Parameters are not valid JSON: {jde}") from jde - - try: - task = Task(name=name, params=parsed_params) - except ValidationError as ve: - ip = InvalidParameters.from_validation_error(ve) - raise ClickException(ip.message()) from ip - try: if foreground: progress_bar = CliEventRenderer() @@ -266,12 +271,12 @@ def on_event(event: AnyEvent) -> None: elif isinstance(event, DataEvent): callback(event.name, event.doc) - resp = client.run_task(task, on_event=on_event) + resp = client.run_task(name, parameters, on_event=on_event) if resp.task_status is not None and not resp.task_status.task_failed: print("Plan Succeeded") else: - server_task = client.create_and_start_task(task) + server_task = client.create_and_start_task(name, parameters) click.echo(server_task.task_id) except config.MissingStompConfiguration as mse: raise ClickException(*mse.args) from mse diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 113b14686..614db1ada 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -1,5 +1,6 @@ import time from concurrent.futures import Future +from typing import Any from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker @@ -7,6 +8,7 @@ get_tracer, start_as_current_span, ) +from pydantic import ValidationError from blueapi.config import ApplicationConfig, MissingStompConfiguration from blueapi.core.bluesky_types import DataEvent @@ -28,10 +30,12 @@ from blueapi.worker.event import ProgressEvent, TaskStatus from .event_bus import AnyEvent, BlueskyStreamingError, EventBusClient, OnAnyEvent -from .rest import BlueapiRestClient, BlueskyRemoteControlError +from .rest import BlueapiRestClient, BlueskyRemoteControlError, InvalidParameters TRACER = get_tracer("client") +TaskParameters = dict[str, Any] + class BlueapiClient: """Unified client for controlling blueapi""" @@ -194,10 +198,12 @@ def get_active_task(self) -> WorkerTask: return self._rest.get_active_task() - @start_as_current_span(TRACER, "task", "timeout") + @start_as_current_span(TRACER, "name", "parameters", "timeout") def run_task( self, - task: Task, + name: str, + parameters: TaskParameters = {}, + *, on_event: OnAnyEvent | None = None, timeout: float | None = None, ) -> WorkerEvent: @@ -220,7 +226,7 @@ def run_task( "Stomp configuration required to run plans is missing or disabled" ) - task_response = self.create_task(task) + task_response = self.create_task(name, parameters) task_id = task_response.task_id complete: Future[WorkerEvent] = Future() @@ -257,8 +263,10 @@ def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: self.start_task(WorkerTask(task_id=task_id)) return complete.result(timeout=timeout) - @start_as_current_span(TRACER, "task") - def create_and_start_task(self, task: Task) -> TaskResponse: + @start_as_current_span(TRACER, "name", "parameters") + def create_and_start_task( + self, name: str, parameters: TaskParameters = {} + ) -> TaskResponse: """ Create a new task and instruct the worker to start it immediately. @@ -270,7 +278,7 @@ def create_and_start_task(self, task: Task) -> TaskResponse: TaskResponse: Acknowledgement of request """ - response = self.create_task(task) + response = self.create_task(name, parameters) worker_response = self.start_task(WorkerTask(task_id=response.task_id)) if worker_response.task_id == response.task_id: return response @@ -280,8 +288,8 @@ def create_and_start_task(self, task: Task) -> TaskResponse: f"but {worker_response.task_id} was started instead" ) - @start_as_current_span(TRACER, "task") - def create_task(self, task: Task) -> TaskResponse: + @start_as_current_span(TRACER, "name", "parameters") + def create_task(self, name: str, parameters: TaskParameters = {}) -> TaskResponse: """ Create a new task, does not start execution @@ -292,6 +300,10 @@ def create_task(self, task: Task) -> TaskResponse: TaskResponse: Acknowledgement of request """ + try: + task = Task(name=name, params=parameters) + except ValidationError as ve: + raise InvalidParameters.from_validation_error(ve) return self._rest.create_task(task) @start_as_current_span(TRACER) diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index 28c764488..2d343cbc9 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -171,7 +171,7 @@ def test_create_task( client: BlueapiClient, mock_rest: Mock, ): - client.create_task(task=Task(name="foo")) + client.create_task(name="foo") mock_rest.create_task.assert_called_once_with(Task(name="foo")) @@ -179,7 +179,7 @@ def test_create_task_does_not_start_task( client: BlueapiClient, mock_rest: Mock, ): - client.create_task(task=Task(name="foo")) + client.create_task(name="foo") mock_rest.update_worker_task.assert_not_called() @@ -218,7 +218,7 @@ def test_create_and_start_task_calls_both_creating_and_starting_endpoints( ): mock_rest.create_task.return_value = TaskResponse(task_id="baz") mock_rest.update_worker_task.return_value = TaskResponse(task_id="baz") - client.create_and_start_task(Task(name="baz")) + client.create_and_start_task(name="baz") mock_rest.create_task.assert_called_once_with(Task(name="baz")) mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="baz")) @@ -229,7 +229,7 @@ def test_create_and_start_task_fails_if_task_creation_fails( ): mock_rest.create_task.side_effect = BlueskyRemoteControlError("No can do") with pytest.raises(BlueskyRemoteControlError): - client.create_and_start_task(Task(name="baz")) + client.create_and_start_task(name="baz") def test_create_and_start_task_fails_if_task_id_is_wrong( @@ -239,7 +239,7 @@ def test_create_and_start_task_fails_if_task_id_is_wrong( mock_rest.create_task.return_value = TaskResponse(task_id="baz") mock_rest.update_worker_task.return_value = TaskResponse(task_id="bar") with pytest.raises(BlueskyRemoteControlError): - client.create_and_start_task(Task(name="baz")) + client.create_and_start_task(name="baz") def test_create_and_start_task_fails_if_task_start_fails( @@ -249,7 +249,7 @@ def test_create_and_start_task_fails_if_task_start_fails( mock_rest.create_task.return_value = TaskResponse(task_id="baz") mock_rest.update_worker_task.side_effect = BlueskyRemoteControlError("No can do") with pytest.raises(BlueskyRemoteControlError): - client.create_and_start_task(Task(name="baz")) + client.create_and_start_task(name="baz") def test_get_environment(client: BlueapiClient): @@ -384,7 +384,7 @@ def test_cannot_run_task_without_message_bus(client: BlueapiClient): MissingStompConfiguration, match="Stomp configuration required to run plans is missing or disabled", ): - client.run_task(Task(name="foo")) + client.run_task(name="foo") def test_run_task_sets_up_control( @@ -398,7 +398,7 @@ def test_run_task_sets_up_control( ctx.correlation_id = "foo" mock_events.subscribe_to_all_events = lambda on_event: on_event(COMPLETE_EVENT, ctx) - client_with_events.run_task(Task(name="foo")) + client_with_events.run_task(name="foo") mock_rest.create_task.assert_called_once_with(Task(name="foo")) mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="foo")) @@ -417,7 +417,7 @@ def test_run_task_fails_on_failing_event( on_event = Mock() with pytest.raises(BlueskyStreamingError): - client_with_events.run_task(Task(name="foo"), on_event=on_event) + client_with_events.run_task(name="foo", on_event=on_event) on_event.assert_called_with(FAILED_EVENT) @@ -456,7 +456,7 @@ def callback(on_event: Callable[[AnyEvent, MessageContext], None]): mock_events.subscribe_to_all_events = callback # type: ignore mock_on_event = Mock() - client_with_events.run_task(Task(name="foo"), on_event=mock_on_event) + client_with_events.run_task(name="foo", on_event=mock_on_event) assert mock_on_event.mock_calls == [call(test_event), call(COMPLETE_EVENT)] @@ -495,7 +495,7 @@ def callback(on_event: Callable[[AnyEvent, MessageContext], None]): mock_events.subscribe_to_all_events = callback mock_on_event = Mock() - client_with_events.run_task(Task(name="foo"), on_event=mock_on_event) + client_with_events.run_task(name="foo", on_event=mock_on_event) mock_on_event.assert_called_once_with(COMPLETE_EVENT) @@ -543,8 +543,8 @@ def test_create_task_span_ok( client: BlueapiClient, mock_rest: Mock, ): - with asserting_span_exporter(exporter, "create_task", "task"): - client.create_task(task=Task(name="foo")) + with asserting_span_exporter(exporter, "create_task", "name", "parameters"): + client.create_task(name="foo") def test_clear_task_span_ok( @@ -579,8 +579,10 @@ def test_create_and_start_task_span_ok( ): mock_rest.create_task.return_value = TaskResponse(task_id="baz") mock_rest.update_worker_task.return_value = TaskResponse(task_id="baz") - with asserting_span_exporter(exporter, "create_and_start_task", "task"): - client.create_and_start_task(Task(name="baz")) + with asserting_span_exporter( + exporter, "create_and_start_task", "name", "parameters" + ): + client.create_and_start_task(name="baz") def test_get_environment_span_ok( @@ -644,4 +646,4 @@ def test_cannot_run_task_span_ok( match="Stomp configuration required to run plans is missing or disabled", ): with asserting_span_exporter(exporter, "grun_task"): - client.run_task(Task(name="foo")) + client.run_task(name="foo") diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 5c26b159a..26d1f4695 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -586,13 +586,13 @@ def test_error_handling(exception, error_message, runner: CliRunner): @pytest.mark.parametrize( - "params, error", + "params, error, code", [ - ("{", "Parameters are not valid JSON"), - ("[]", ""), + ("{", "Invalid value for '[PARAMETERS]'", 2), + ("[]", "Incorrect parameters supplied", 1), ], ) -def test_run_task_parsing_errors(params: str, error: str, runner: CliRunner): +def test_run_task_parsing_errors(params: str, error: str, code: int, runner: CliRunner): result = runner.invoke( main, [ @@ -604,8 +604,8 @@ def test_run_task_parsing_errors(params: str, error: str, runner: CliRunner): params, ], ) - assert result.stderr.startswith("Error: " + error) - assert result.exit_code == 1 + assert ("Error: " + error) in result.stderr + assert result.exit_code == code def test_device_output_formatting(): From 69de92777d07fabd8c6e00bfb24ea60174d33cfd Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Mon, 23 Jun 2025 10:05:49 +0100 Subject: [PATCH 2/5] Lints, docs and system tests --- src/blueapi/cli/cli.py | 13 ++++- src/blueapi/client/client.py | 12 ++-- tests/system_tests/test_blueapi_system.py | 68 +++++++++++------------ 3 files changed, 50 insertions(+), 43 deletions(-) diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 7e012dca4..7ef39776a 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -6,7 +6,7 @@ from functools import wraps from pathlib import Path from pprint import pprint -from typing import Any +from typing import Any, cast import click from bluesky.callbacks.best_effort import BestEffortCallback @@ -43,6 +43,8 @@ class ParametersType(ParamType): + """CLI input parameter to accept a JSON object as an argument""" + name = "TaskParameters" def convert( @@ -50,11 +52,16 @@ def convert( ) -> TaskParameters: if isinstance(value, str): try: - return json.loads(value) + params = json.loads(value) + if not isinstance(params, dict) or any( + not isinstance(k, str) for k in params + ): + self.fail("Parameters must be a JSON object") + return cast(TaskParameters, params) except json.JSONDecodeError as jde: self.fail(f"Parameters are not valid JSON: {jde}") else: - return super().convert(value, param, ctx) + return cast(TaskParameters, super().convert(value, param, ctx)) @click.group( diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 614db1ada..27302c405 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -202,7 +202,7 @@ def get_active_task(self) -> WorkerTask: def run_task( self, name: str, - parameters: TaskParameters = {}, + parameters: TaskParameters | None = None, *, on_event: OnAnyEvent | None = None, timeout: float | None = None, @@ -226,7 +226,7 @@ def run_task( "Stomp configuration required to run plans is missing or disabled" ) - task_response = self.create_task(name, parameters) + task_response = self.create_task(name, parameters or {}) task_id = task_response.task_id complete: Future[WorkerEvent] = Future() @@ -265,7 +265,7 @@ def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: @start_as_current_span(TRACER, "name", "parameters") def create_and_start_task( - self, name: str, parameters: TaskParameters = {} + self, name: str, parameters: TaskParameters | None = None ) -> TaskResponse: """ Create a new task and instruct the worker to start it @@ -278,7 +278,7 @@ def create_and_start_task( TaskResponse: Acknowledgement of request """ - response = self.create_task(name, parameters) + response = self.create_task(name, parameters or {}) worker_response = self.start_task(WorkerTask(task_id=response.task_id)) if worker_response.task_id == response.task_id: return response @@ -289,7 +289,7 @@ def create_and_start_task( ) @start_as_current_span(TRACER, "name", "parameters") - def create_task(self, name: str, parameters: TaskParameters = {}) -> TaskResponse: + def create_task(self, name: str, parameters: TaskParameters) -> TaskResponse: """ Create a new task, does not start execution @@ -303,7 +303,7 @@ def create_task(self, name: str, parameters: TaskParameters = {}) -> TaskRespons try: task = Task(name=name, params=parameters) except ValidationError as ve: - raise InvalidParameters.from_validation_error(ve) + raise InvalidParameters.from_validation_error(ve) from ve return self._rest.create_task(task) @start_as_current_span(TRACER) diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index f60fe07a5..f0762d6be 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -1,6 +1,7 @@ import inspect import time from pathlib import Path +from typing import Any import pytest from bluesky_stomp.models import BasicAuthentication @@ -26,11 +27,10 @@ WorkerTask, ) from blueapi.worker.event import TaskStatus, WorkerEvent, WorkerState -from blueapi.worker.task import Task from blueapi.worker.task_worker import TrackableTask -_SIMPLE_TASK = Task(name="sleep", params={"time": 0.0}) -_LONG_TASK = Task(name="sleep", params={"time": 1.0}) +_SIMPLE_TASK = ("sleep", {"time": 0.0}) +_LONG_TASK = ("sleep", {"time": 1.0}) _DATA_PATH = Path(__file__).parent @@ -182,19 +182,19 @@ def test_get_non_existent_device(client: BlueapiClient): def test_create_task_and_delete_task_by_id(client: BlueapiClient): - create_task = client.create_task(_SIMPLE_TASK) + create_task = client.create_task(*_SIMPLE_TASK) client.clear_task(create_task.task_id) def test_create_task_validation_error(client: BlueapiClient): with pytest.raises(UnknownPlan): - client.create_task(Task(name="Not-exists", params={"Not-exists": 0.0})) + client.create_task("Not-exists", {"Not-exists": 0.0}) def test_get_all_tasks(client: BlueapiClient): created_tasks: list[TaskResponse] = [] for task in [_SIMPLE_TASK, _LONG_TASK]: - created_task = client.create_task(task) + created_task = client.create_task(*task) created_tasks.append(created_task) task_ids = [task.task_id for task in created_tasks] @@ -208,7 +208,7 @@ def test_get_all_tasks(client: BlueapiClient): def test_get_task_by_id(client: BlueapiClient): - created_task = client.create_task(_SIMPLE_TASK) + created_task = client.create_task(*_SIMPLE_TASK) get_task = client.get_task(created_task.task_id) assert ( @@ -232,7 +232,7 @@ def test_delete_non_existent_task(client: BlueapiClient): def test_put_worker_task(client: BlueapiClient): - created_task = client.create_task(_SIMPLE_TASK) + created_task = client.create_task(*_SIMPLE_TASK) client.start_task(WorkerTask(task_id=created_task.task_id)) active_task = client.get_active_task() assert active_task.task_id == created_task.task_id @@ -240,8 +240,8 @@ def test_put_worker_task(client: BlueapiClient): def test_put_worker_task_fails_if_not_idle(client: BlueapiClient): - small_task = client.create_task(_SIMPLE_TASK) - long_task = client.create_task(_LONG_TASK) + small_task = client.create_task(*_SIMPLE_TASK) + long_task = client.create_task(*_LONG_TASK) client.start_task(WorkerTask(task_id=long_task.task_id)) active_task = client.get_active_task() @@ -269,8 +269,8 @@ def test_set_state_transition_error(client: BlueapiClient): def test_get_task_by_status(client: BlueapiClient): - task_1 = client.create_task(_SIMPLE_TASK) - task_2 = client.create_task(_SIMPLE_TASK) + task_1 = client.create_task(*_SIMPLE_TASK) + task_2 = client.create_task(*_SIMPLE_TASK) task_by_pending = client.get_all_tasks() # https://github.com/DiamondLightSource/blueapi/issues/680 # task_by_pending = client.get_tasks_by_status(TaskStatusEnum.PENDING) @@ -305,7 +305,7 @@ def test_progress_with_stomp(client_with_stomp: BlueapiClient): def on_event(event: AnyEvent): all_events.append(event) - client_with_stomp.run_task(_SIMPLE_TASK, on_event=on_event) + client_with_stomp.run_task(*_SIMPLE_TASK, on_event=on_event) assert isinstance(all_events[0], WorkerEvent) and all_events[0].task_status task_id = all_events[0].task_status.task_id assert all_events == [ @@ -350,11 +350,11 @@ def test_delete_current_environment(client: BlueapiClient): @pytest.mark.parametrize( - "task", + "plan,params", [ - Task( - name="count", - params={ + ( + "count", + { "detectors": [ "image_det", "current_det", @@ -362,9 +362,9 @@ def test_delete_current_environment(client: BlueapiClient): "num": 5, }, ), - Task( - name="spec_scan", - params={ + ( + "spec_scan", + { "detectors": [ "image_det", "current_det", @@ -372,34 +372,34 @@ def test_delete_current_environment(client: BlueapiClient): "spec": Line("x", 0.0, 10.0, 2) * Line("y", 5.0, 15.0, 3), }, ), - Task( - name="set_absolute", - params={ + ( + "set_absolute", + { "movable": "dynamic_motor", "value": "bar", }, ), - Task( - name="motor_plan", - params={ + ( + "motor_plan", + { "motor": "movable_motor", }, ), - Task( - name="motor_plan", - params={ + ( + "motor_plan", + { "motor": "dynamic_motor", }, ), - Task( - name="dataclass_motor_plan", - params={ + ( + "dataclass_motor_plan", + { "motor": "data_class_motor", }, ), ], ) -def test_plan_runs(client_with_stomp: BlueapiClient, task: Task): - final_event = client_with_stomp.run_task(task) +def test_plan_runs(client_with_stomp: BlueapiClient, plan: str, params: dict[str, Any]): + final_event = client_with_stomp.run_task(plan, params) assert final_event.is_complete() and not final_event.is_error() assert final_event.state is WorkerState.IDLE From 41771574b77d6b9294486539dcef2042c7c09572 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Mon, 23 Jun 2025 10:51:47 +0100 Subject: [PATCH 3/5] Fix test refactoring --- src/blueapi/client/client.py | 6 ++++-- tests/unit_tests/test_cli.py | 12 ++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 27302c405..cedf86329 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -289,7 +289,9 @@ def create_and_start_task( ) @start_as_current_span(TRACER, "name", "parameters") - def create_task(self, name: str, parameters: TaskParameters) -> TaskResponse: + def create_task( + self, name: str, parameters: TaskParameters | None = None + ) -> TaskResponse: """ Create a new task, does not start execution @@ -301,7 +303,7 @@ def create_task(self, name: str, parameters: TaskParameters) -> TaskResponse: """ try: - task = Task(name=name, params=parameters) + task = Task(name=name, params=parameters or {}) except ValidationError as ve: raise InvalidParameters.from_validation_error(ve) from ve return self._rest.create_task(task) diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 26d1f4695..60bd4c31d 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -586,13 +586,13 @@ def test_error_handling(exception, error_message, runner: CliRunner): @pytest.mark.parametrize( - "params, error, code", + "params", [ - ("{", "Invalid value for '[PARAMETERS]'", 2), - ("[]", "Incorrect parameters supplied", 1), + "{", + "[]", ], ) -def test_run_task_parsing_errors(params: str, error: str, code: int, runner: CliRunner): +def test_run_task_parsing_errors(params: str, runner: CliRunner): result = runner.invoke( main, [ @@ -604,8 +604,8 @@ def test_run_task_parsing_errors(params: str, error: str, code: int, runner: Cli params, ], ) - assert ("Error: " + error) in result.stderr - assert result.exit_code == code + assert "Error: Invalid value for '[PARAMETERS]'" in result.stderr + assert result.exit_code == 2 def test_device_output_formatting(): From b6d00861cfef3ba67ea9ff4f780abafb33802297 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Mon, 23 Jun 2025 12:21:34 +0100 Subject: [PATCH 4/5] Missing test coverage --- src/blueapi/client/client.py | 8 ++------ tests/unit_tests/test_cli.py | 8 +++++++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index cedf86329..e275b0f6c 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -8,7 +8,6 @@ get_tracer, start_as_current_span, ) -from pydantic import ValidationError from blueapi.config import ApplicationConfig, MissingStompConfiguration from blueapi.core.bluesky_types import DataEvent @@ -30,7 +29,7 @@ from blueapi.worker.event import ProgressEvent, TaskStatus from .event_bus import AnyEvent, BlueskyStreamingError, EventBusClient, OnAnyEvent -from .rest import BlueapiRestClient, BlueskyRemoteControlError, InvalidParameters +from .rest import BlueapiRestClient, BlueskyRemoteControlError TRACER = get_tracer("client") @@ -302,10 +301,7 @@ def create_task( TaskResponse: Acknowledgement of request """ - try: - task = Task(name=name, params=parameters or {}) - except ValidationError as ve: - raise InvalidParameters.from_validation_error(ve) from ve + task = Task(name=name, params=parameters or {}) return self._rest.create_task(task) @start_as_current_span(TRACER) diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 60bd4c31d..b6e794716 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -24,7 +24,7 @@ from stomp.connect import StompConnection11 as Connection from blueapi import __version__ -from blueapi.cli.cli import main +from blueapi.cli.cli import ParametersType, main from blueapi.cli.format import OutputFormat, fmt_dict from blueapi.client.event_bus import BlueskyStreamingError from blueapi.client.rest import ( @@ -1164,3 +1164,9 @@ def test_python_env_output_formatting(): """) _assert_matching_formatting(OutputFormat.FULL, empty_python_env, full) + + +@pytest.mark.parametrize("value,result", [({}, {}), ("{}", {}), (None, None)]) +def test_task_parameter_type(value, result): + t = ParametersType() + assert t.convert(value, None, None) == result From 9ad70fcbe7e573e0ad9aa6f62dfda48aa455d960 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 24 Jun 2025 14:07:15 +0100 Subject: [PATCH 5/5] Use type guard instead of cast --- src/blueapi/cli/cli.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 7ef39776a..16a38f796 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -6,7 +6,7 @@ from functools import wraps from pathlib import Path from pprint import pprint -from typing import Any, cast +from typing import Any, TypeGuard import click from bluesky.callbacks.best_effort import BestEffortCallback @@ -48,20 +48,25 @@ class ParametersType(ParamType): name = "TaskParameters" def convert( - self, value: Any, param: Parameter | None, ctx: Context | None + self, + value: str | dict[str, Any] | None, + param: Parameter | None, + ctx: Context | None, ) -> TaskParameters: if isinstance(value, str): try: params = json.loads(value) - if not isinstance(params, dict) or any( - not isinstance(k, str) for k in params - ): - self.fail("Parameters must be a JSON object") - return cast(TaskParameters, params) + if is_str_dict(params): + return params + self.fail("Parameters must be a JSON object with string keys") except json.JSONDecodeError as jde: self.fail(f"Parameters are not valid JSON: {jde}") else: - return cast(TaskParameters, super().convert(value, param, ctx)) + return super().convert(value, param, ctx) + + +def is_str_dict(val: Any) -> TypeGuard[TaskParameters]: + return isinstance(val, dict) and all(isinstance(k, str) for k in val) @click.group(