From 171dbd58591f857ff86aac63be6af2a9b05cb572 Mon Sep 17 00:00:00 2001 From: WilliamTakeshi Date: Wed, 21 Jan 2026 14:08:35 +0100 Subject: [PATCH 1/3] feat: add websockets --- dotbot/examples/charging_station.py | 62 ++++++++++++------ dotbot/models.py | 30 ++++++++- dotbot/server.py | 63 +++++++++++++++++++ .../tests/test_experiment_charging_station.py | 59 ++++++++++++++++- dotbot/websocket.py | 20 ++++++ 5 files changed, 211 insertions(+), 23 deletions(-) create mode 100644 dotbot/websocket.py diff --git a/dotbot/examples/charging_station.py b/dotbot/examples/charging_station.py index da9acf56..05f97b7b 100644 --- a/dotbot/examples/charging_station.py +++ b/dotbot/examples/charging_station.py @@ -15,9 +15,12 @@ DotBotMoveRawCommandModel, DotBotRgbLedCommandModel, DotBotWaypoints, + WSRgbLed, + WSWaypoints, ) from dotbot.protocol import ApplicationType -from dotbot.rest import RestClient +from dotbot.rest import RestClient, rest_client +from dotbot.websocket import DotBotWsClient THRESHOLD = 30 # Acceptable distance error to consider a waypoint reached DT = 0.05 # Control loop period (seconds) @@ -40,16 +43,18 @@ async def queue_robots( client: RestClient, + ws: DotBotWsClient, dotbots: List[DotBotModel], params: OrcaParams, ) -> None: sorted_bots = order_bots(dotbots, QUEUE_HEAD_X, QUEUE_HEAD_Y) goals = assign_queue_goals(sorted_bots, QUEUE_HEAD_X, QUEUE_HEAD_Y, QUEUE_SPACING) - await send_to_goal(client, goals, params) + await send_to_goal(client, ws, goals, params) async def charge_robots( client: RestClient, + ws: DotBotWsClient, params: OrcaParams, ) -> None: dotbots = await client.fetch_active_dotbots() @@ -76,7 +81,7 @@ async def charge_robots( "x": PARK_X, "y": PARK_Y + parked_count * PARK_SPACING, } - await send_to_goal(client, goals, params) + await send_to_goal(client, ws, goals, params) if len(remaining) == 0: break @@ -123,6 +128,7 @@ async def disengage_from_charger(client: RestClient, dotbot: DotBotModel): async def send_to_goal( client: RestClient, + ws: DotBotWsClient, goals: Dict[str, dict], params: OrcaParams, ) -> None: @@ -178,11 +184,15 @@ async def send_to_goal( ) ], ) - await client.send_waypoint_command( - address=agent.id, - application=ApplicationType.DotBot, - command=waypoints, + await ws.send( + WSWaypoints( + cmd="waypoints", + address=agent.id, + application=ApplicationType.DotBot, + data=waypoints, + ) ) + await asyncio.sleep(DT) return None @@ -299,22 +309,34 @@ async def main() -> None: url = os.getenv("DOTBOT_CONTROLLER_URL", "localhost") port = os.getenv("DOTBOT_CONTROLLER_PORT", "8000") use_https = os.getenv("DOTBOT_CONTROLLER_USE_HTTPS", False) - client = RestClient(url, port, use_https) - - dotbots = await client.fetch_active_dotbots() + async with rest_client(url, port, use_https) as client: + dotbots = await client.fetch_active_dotbots() - # Cosmetic: all bots are red - for dotbot in dotbots: - await client.send_rgb_led_command( - address=dotbot.address, - command=DotBotRgbLedCommandModel(red=255, green=0, blue=0), - ) + ws = DotBotWsClient(url, port) + await ws.connect() + try: + # Cosmetic: all bots are red + for dotbot in dotbots: + await ws.send( + WSRgbLed( + cmd="rgb_led", + address=dotbot.address, + application=ApplicationType.DotBot, + data=DotBotRgbLedCommandModel( + red=255, + green=0, + blue=0, + ), + ) + ) - # Phase 1: initial queue - await queue_robots(client, dotbots, params) + # Phase 1: initial queue + await queue_robots(client, ws, dotbots, params) - # Phase 2: charging loop - await charge_robots(client, params) + # Phase 2: charging loop + await charge_robots(client, ws, params) + finally: + await ws.close() return None diff --git a/dotbot/models.py b/dotbot/models.py index f05002c2..2a88c145 100644 --- a/dotbot/models.py +++ b/dotbot/models.py @@ -10,7 +10,7 @@ # pylint: disable=too-few-public-methods,no-name-in-module from enum import IntEnum -from typing import Any, List, Optional, Union +from typing import Any, List, Literal, Optional, Union from pydantic import BaseModel @@ -172,3 +172,31 @@ class DotBotModel(BaseModel): position_history: List[Union[DotBotLH2Position, DotBotGPSPosition]] = [] calibrated: bool = False battery: float = 0.0 # Voltage in Volts + + +class WSBase(BaseModel): + cmd: str + address: str + application: ApplicationType + + +class WSRgbLed(WSBase): + cmd: Literal["rgb_led"] + data: DotBotRgbLedCommandModel + + +class WSMoveRaw(WSBase): + cmd: Literal["move_raw"] + data: DotBotMoveRawCommandModel + + +class WSWaypoints(WSBase): + cmd: Literal["waypoints"] + data: DotBotWaypoints + + +WSMessage = Union[ + WSRgbLed, + WSMoveRaw, + WSWaypoints, +] diff --git a/dotbot/server.py b/dotbot/server.py index 0a564c66..200a6019 100644 --- a/dotbot/server.py +++ b/dotbot/server.py @@ -13,6 +13,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import Response from fastapi.staticfiles import StaticFiles +from pydantic import TypeAdapter, ValidationError from starlette.middleware.base import BaseHTTPMiddleware from dotbot import pydotbot_version @@ -25,6 +26,10 @@ DotBotQueryModel, DotBotRgbLedCommandModel, DotBotWaypoints, + WSMessage, + WSMoveRaw, + WSRgbLed, + WSWaypoints, ) from dotbot.protocol import ( ApplicationType, @@ -40,6 +45,8 @@ "PYDOTBOT_FRONTEND_BASE_URL", "https://dotbots.github.io/PyDotBot" ) +ws_adapter = TypeAdapter(WSMessage) + class ReverseProxyMiddleware(BaseHTTPMiddleware): @@ -98,6 +105,10 @@ async def dotbots_move_raw( if address not in api.controller.dotbots: raise HTTPException(status_code=404, detail="No matching dotbot found") + _dotbots_move_raw(address=address, command=command) + + +def _dotbots_move_raw(address: str, command: DotBotMoveRawCommandModel): payload = PayloadCommandMoveRaw( left_x=command.left_x, left_y=command.left_y, @@ -120,6 +131,10 @@ async def dotbots_rgb_led( if address not in api.controller.dotbots: raise HTTPException(status_code=404, detail="No matching dotbot found") + _dotbots_rgb_led(address=address, command=command) + + +def _dotbots_rgb_led(address: str, command: DotBotRgbLedCommandModel): payload = PayloadCommandRgbLed( red=command.red, green=command.green, blue=command.blue ) @@ -141,6 +156,16 @@ async def dotbots_waypoints( if address not in api.controller.dotbots: raise HTTPException(status_code=404, detail="No matching dotbot found") + await _dotbots_waypoints( + address=address, application=application, waypoints=waypoints + ) + + +async def _dotbots_waypoints( + address: str, + application: int, + waypoints: DotBotWaypoints, +): waypoints_list = waypoints.waypoints if application == ApplicationType.SailBot.value: if api.controller.dotbots[address].gps_position is not None: @@ -236,6 +261,44 @@ async def websocket_endpoint(websocket: WebSocket): api.controller.websockets.remove(websocket) +@api.websocket("/controller/ws/dotbots") +async def ws_dotbots(websocket: WebSocket): + await websocket.accept() + try: + while True: + raw = await websocket.receive_json() + + try: + msg = ws_adapter.validate_python(raw) + except ValidationError as e: + await websocket.send_json( + { + "error": "invalid_message", + "details": e.errors(), + } + ) + continue + if isinstance(msg, WSRgbLed): + _dotbots_rgb_led( + address=msg.address, + command=msg.data, + ) + elif isinstance(msg, WSMoveRaw): + _dotbots_move_raw( + address=msg.address, + command=msg.data, + ) + elif isinstance(msg, WSWaypoints): + await _dotbots_waypoints( + address=msg.address, + application=msg.application, + waypoints=msg.data, + ) + + except WebSocketDisconnect: + LOGGER.debug("WebSocket client disconnected") + + # Mount static files after all routes are defined FRONTEND_DIR = os.path.join(os.path.dirname(__file__), "frontend", "build") api.mount("/PyDotBot", StaticFiles(directory=FRONTEND_DIR, html=True), name="PyDotBot") diff --git a/dotbot/tests/test_experiment_charging_station.py b/dotbot/tests/test_experiment_charging_station.py index ee500320..c9339e0f 100644 --- a/dotbot/tests/test_experiment_charging_station.py +++ b/dotbot/tests/test_experiment_charging_station.py @@ -24,6 +24,7 @@ DotBotRgbLedCommandModel, DotBotStatus, DotBotWaypoints, + WSMessage, ) from dotbot.protocol import ApplicationType @@ -130,6 +131,56 @@ async def send_rgb_led_command( self.rgb_commands.append((address, command)) +class FakeDotBotWsClient: + """ + Fake WebSocket client for testing control logic. + + - Accepts typed WSMessage objects + - Dispatches to FakeRestClient logic + - Records all WS messages for assertions + """ + + def __init__(self, rest_client: FakeRestClient): + self.rest = rest_client + self.sent_messages: list[WSMessage] = [] + self.connected = False + + async def connect(self): + self.connected = True + + async def close(self): + self.connected = False + + async def send(self, msg: WSMessage): + if not self.connected: + raise RuntimeError("FakeDotBotWsClient is not connected") + + self.sent_messages.append(msg) + + if msg.cmd == "rgb_led": + await self.rest.send_rgb_led_command( + address=msg.address, + command=msg.data, + ) + + elif msg.cmd == "move_raw": + await self.rest.send_move_raw_command( + address=msg.address, + application=msg.application, + command=msg.data, + ) + + elif msg.cmd == "waypoints": + await self.rest.send_waypoint_command( + address=msg.address, + application=msg.application, + command=msg.data, + ) + + else: + raise ValueError(f"Unknown WS command: {msg.cmd}") + + def fake_bot(address: str, x: float, y: float) -> DotBotModel: return DotBotModel( address=address, @@ -151,9 +202,11 @@ async def test_queue_robots_converges_to_queue_positions(_): ] client = FakeRestClient(bots) + ws = FakeDotBotWsClient(client) + await ws.connect() params = OrcaParams(time_horizon=5 * DT, time_step=DT) - await queue_robots(client, bots, params) + await queue_robots(client, ws, bots, params) # Bots should be ordered A, B, C along the queue expected = { @@ -184,9 +237,11 @@ async def test_charge_robots_moves_all_bots_to_parking(_): ] client = FakeRestClient(bots) + ws = FakeDotBotWsClient(client) + await ws.connect() params = OrcaParams(time_horizon=5 * DT, time_step=DT) - await charge_robots(client, params) + await charge_robots(client, ws, params) # --- Assertions: all bots parked --- # Bots should be ordered A, B, C along the park slots diff --git a/dotbot/websocket.py b/dotbot/websocket.py new file mode 100644 index 00000000..f9920008 --- /dev/null +++ b/dotbot/websocket.py @@ -0,0 +1,20 @@ +from dotbot.models import WSMessage + + +class DotBotWsClient: + def __init__(self, host, port): + self.url = f"ws://{host}:{port}/controller/ws/dotbots" + self.ws = None + + async def connect(self): + import websockets + + self.ws = await websockets.connect(self.url) + + async def close(self): + await self.ws.close() + + async def send(self, msg: WSMessage): + if not self.ws: + raise RuntimeError("WebSocket not connected") + await self.ws.send(msg.model_dump_json()) From 63c4bfb5f3f71ac1bb8cca6bfed14e440fa7e2e9 Mon Sep 17 00:00:00 2001 From: WilliamTakeshi Date: Fri, 23 Jan 2026 11:21:16 +0100 Subject: [PATCH 2/3] test: add case for ws --- dotbot/server.py | 5 ++ dotbot/tests/test_server.py | 129 ++++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+) diff --git a/dotbot/server.py b/dotbot/server.py index 200a6019..269c5be0 100644 --- a/dotbot/server.py +++ b/dotbot/server.py @@ -278,6 +278,11 @@ async def ws_dotbots(websocket: WebSocket): } ) continue + + if msg.address not in api.controller.dotbots: + # ignore messages where address doesn't exist + continue + if isinstance(msg, WSRgbLed): _dotbots_rgb_led( address=msg.address, diff --git a/dotbot/tests/test_server.py b/dotbot/tests/test_server.py index 4fbb61fa..f03419e7 100644 --- a/dotbot/tests/test_server.py +++ b/dotbot/tests/test_server.py @@ -12,6 +12,10 @@ DotBotModel, DotBotMoveRawCommandModel, DotBotRgbLedCommandModel, + DotBotWaypoints, + WSMoveRaw, + WSRgbLed, + WSWaypoints, ) from dotbot.protocol import ( ApplicationType, @@ -609,3 +613,128 @@ def mock_async_client(*args, **kwargs): # serve.side_effect = asyncio.exceptions.CancelledError() # await web(None) # assert "Web server cancelled" in caplog.text + + +@pytest.mark.parametrize( + "dotbots,ws_message,expected_payload,should_call", + [ + pytest.param( + # ---- RGB LED (valid) ---- + { + "4242": DotBotModel( + address="4242", + application=ApplicationType.DotBot, + swarm="0000", + last_seen=123.4, + ) + }, + WSRgbLed( + cmd="rgb_led", + address="4242", + application=ApplicationType.DotBot, + data=DotBotRgbLedCommandModel( + red=255, + green=0, + blue=128, + ), + ), + PayloadCommandRgbLed(red=255, green=0, blue=128), + True, + id="rgb_led_valid", + ), + pytest.param( + # ---- WAYPOINTS (valid) ---- + { + "4242": DotBotModel( + address="4242", + application=ApplicationType.DotBot, + swarm="0000", + last_seen=123.4, + ) + }, + WSWaypoints( + cmd="waypoints", + address="4242", + application=ApplicationType.DotBot, + data=DotBotWaypoints( + threshold=10, + waypoints=[DotBotLH2Position(x=0.5, y=0.1, z=0)], + ), + ), + PayloadLH2Waypoints( + threshold=10, + count=1, + waypoints=[PayloadLH2Location(pos_x=500000, pos_y=100000, pos_z=0)], + ), + True, + id="waypoints_valid", + ), + pytest.param( + # ---- MOVE_RAW (valid) ---- + { + "4242": DotBotModel( + address="4242", + application=ApplicationType.DotBot, + swarm="0000", + last_seen=123.4, + ) + }, + WSMoveRaw( + cmd="move_raw", + address="4242", + application=ApplicationType.DotBot, + data=DotBotMoveRawCommandModel( + left_x=0, + left_y=100, + right_x=0, + right_y=100, + ), + ), + PayloadCommandMoveRaw( + left_x=0, + left_y=100, + right_x=0, + right_y=100, + ), + True, + id="move_raw_valid", + ), + pytest.param( + # ---- UNKNOWN ADDRESS (ignored) ---- + {}, + WSRgbLed( + cmd="rgb_led", + address="4242", + application=ApplicationType.DotBot, + data=DotBotRgbLedCommandModel( + red=255, + green=0, + blue=128, + ), + ), + None, + False, + id="address_not_found", + ), + ], +) +def test_ws_dotbots_commands( + dotbots, + ws_message, + expected_payload, + should_call, +): + api.controller.dotbots = dotbots + + with TestClient(api).websocket_connect("/controller/ws/dotbots") as ws: + ws.send_json(ws_message.model_dump()) + + if should_call: + api.controller.send_payload.assert_called() + if expected_payload is not None: + api.controller.send_payload.assert_called_with( + int(ws_message.address, 16), + expected_payload, + ) + else: + api.controller.send_payload.assert_not_called() From f3501cd4ce9948dbd2c74df684cecf399c9c3ae3 Mon Sep 17 00:00:00 2001 From: WilliamTakeshi Date: Fri, 23 Jan 2026 14:18:45 +0100 Subject: [PATCH 3/3] test: case for invalid messages on WS --- dotbot/tests/test_server.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/dotbot/tests/test_server.py b/dotbot/tests/test_server.py index f03419e7..fd46fba7 100644 --- a/dotbot/tests/test_server.py +++ b/dotbot/tests/test_server.py @@ -738,3 +738,37 @@ def test_ws_dotbots_commands( ) else: api.controller.send_payload.assert_not_called() + + +@pytest.mark.asyncio +def test_ws_invalid_message_validation_error(): + api.controller.dotbots = { + "4242": DotBotModel( + address="4242", + application=ApplicationType.DotBot, + swarm="0000", + last_seen=123.4, + ) + } + + invalid_message = { + # cmd doesn't match with data + "cmd": "waypoints", + "address": "4242", + "data": { + "red": 255, + "green": 0, + "blue": 0, + }, + } + + with TestClient(api).websocket_connect("/controller/ws/dotbots") as ws: + ws.send_json(invalid_message) + + response = ws.receive_json() + + assert response["error"] == "invalid_message" + assert "details" in response + assert isinstance(response["details"], list) + + api.controller.send_payload.assert_not_called()